<a href="https://colab.research.google.com/github/ahzaidy/Programs/blob/main/CPSC_5440_HW21.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, models
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pickle
from google.colab import drive

# Mount Google Drive so the CIFAR-100 pickle files stored there can be read.
drive.mount("/content/drive")

# Folder on Drive that holds the CIFAR-100 "train" and "test" pickle files.
DATA_DIR = '/content/drive/My Drive'

# Load CIFAR-100 dataset from Google Drive.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles whose origin you trust (e.g. the official CIFAR-100 archive).
with open(f'{DATA_DIR}/train', 'rb') as file:
    train_dict = pickle.load(file, encoding='bytes')

with open(f'{DATA_DIR}/test', 'rb') as file:
    test_dict = pickle.load(file, encoding='bytes')

# Extract images and fine-grained labels. Each image row is reshaped to
# (3, 32, 32), i.e. channel-first, and scaled by 1/255. Presumably the raw
# pixels are uint8, which yields values in [0, 1]; confirm if the source changes.
train_data = torch.tensor(train_dict[b'data'], dtype=torch.float32).reshape(-1, 3, 32, 32) / 255.0
train_labels = torch.tensor(train_dict[b'fine_labels'], dtype=torch.long)
test_data = torch.tensor(test_dict[b'data'], dtype=torch.float32).reshape(-1, 3, 32, 32) / 255.0
test_labels = torch.tensor(test_dict[b'fine_labels'], dtype=torch.long)

# Training-time augmentation pipeline.
# - The flip/crop/colour-jitter steps are random, so this pipeline is meant for
#   training data only. Evaluation data should receive just the final Normalize
#   step so that accuracy is deterministic.
# - Mean/std are the standard ImageNet channel statistics, matching the
#   ImageNet-pretrained ResNet-18 weights used for the model.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])



class CIFAR100Dataset(torch.utils.data.Dataset):
    """Map-style dataset over images and labels held in memory.

    The optional ``transform`` is applied to the image every time a sample is
    fetched; nothing is cached, so each access re-runs the transform.

    Args:
        data: Indexable collection of images.
        labels: Indexable collection of labels; its length is the dataset length.
        transform: Optional callable applied to each image on access. It is
            skipped when falsy (e.g. ``None``).
    """

    def __init__(self, data, labels, transform=None):
        self.data, self.labels, self.transform = data, labels, transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        img = self.data[idx]
        return (self.transform(img) if self.transform else img), self.labels[idx]

# Evaluation must be deterministic: apply only the normalization, not the random
# flip/crop/colour-jitter augmentation used for training. These mean/std values
# must stay in sync with the Normalize step of the training `transform`.
eval_transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Define datasets; transforms are applied lazily, per sample, on each access.
train_dataset = CIFAR100Dataset(train_data, train_labels, transform=transform)
test_dataset = CIFAR100Dataset(test_data, test_labels, transform=eval_transform)


# Define the model with configurable number of classes
class SimpleResNet(nn.Module):
    """Truncated ResNet-18 classifier.

    The network keeps the ResNet-18 stem plus ``layer1`` and ``layer2``, then
    applies global average pooling and a linear head. ``layer3``, ``layer4`` and
    the original ImageNet ``fc`` are dropped.

    Args:
        num_classes: Number of output logits (default 100, one per CIFAR-100
            fine label).
        pretrained: If True (default), initialise the backbone from
            torchvision's ``ResNet18_Weights.IMAGENET1K_V1``. If False, use
            random initialisation, which also avoids downloading weights.

    NOTE(review): torchvision's ResNet-18 stem (7x7 stride-2 conv + stride-2
    max-pool) downsamples 32x32 inputs to an 8x8 map before ``layer1``, and
    ``layer2`` halves it again to 4x4. A 3x3 stride-1 stem is a common
    alternative for CIFAR-sized images.
    """

    def __init__(self, num_classes=100, pretrained=True):
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        resnet18 = models.resnet18(weights=weights)
        self.features = nn.Sequential(
            resnet18.conv1, resnet18.bn1, resnet18.relu, resnet18.maxpool,
            resnet18.layer1, resnet18.layer2
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The head's input width is the channel count produced by layer2's last
        # block, read from the BatchNorm that normalises that block's output.
        self.fc = nn.Linear(resnet18.layer2[-1].bn2.num_features, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

# Initialize model on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleResNet().to(device)

# Set hyperparameters. This dict is the single source of truth: the optimizer
# below reads from it, and it is what gets reported as the "best" config.
# NOTE(review): lr=0.1 is aggressive for a batch size of 16; consider scaling it.
config = {
    'batch_size': 16,
    'lr': 0.1,
    'epochs': 50,
    'weight_decay': 5e-4,
}

model = torch.compile(model)

# Pinned host memory only speeds up host-to-GPU copies; on CPU it is wasted work.
use_pin_memory = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                          num_workers=2, pin_memory=use_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False,
                         num_workers=2, pin_memory=use_pin_memory)

# Define loss function and optimizer (weight decay taken from config so the
# reported hyperparameters match what was actually trained).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config['lr'], momentum=0.9,
                      weight_decay=config['weight_decay'])

# Learning-rate scheduler. In 'max' mode it watches the metric passed to
# step() (test accuracy here) and multiplies the LR by 0.1 once that metric
# has failed to improve for more than 5 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5)


def _train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one full pass of optimizer updates over ``train_loader``."""
    model.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()


def _evaluate_accuracy(model, data_loader, device):
    """Return the top-1 accuracy of ``model`` over every sample in ``data_loader``.

    Raises:
        ValueError: if the loader yields no samples, since accuracy is undefined.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError('evaluation loader yielded no samples; cannot compute accuracy')
    return correct / total


# Training loop: train, evaluate, step the scheduler, track the best epoch.
def train_cifar100(model, train_loader, test_loader, optimizer, criterion, scheduler, device, config, epochs=50):
    """Train ``model`` for ``epochs`` epochs, evaluating after each one.

    Args:
        model: Classifier returning logits.
        train_loader: Yields ``(inputs, labels)`` training batches.
        test_loader: Yields ``(inputs, labels)`` batches evaluated after every
            epoch. Its accuracy drives the scheduler and best-epoch selection.
            NOTE(review): selecting on the test split leaks test information;
            a held-out validation split would keep the final estimate unbiased.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        scheduler: Stepped once per epoch with that epoch's accuracy (e.g.
            ``ReduceLROnPlateau`` in ``'max'`` mode).
        device: Device that batches are moved to.
        config: Must contain ``'batch_size'``, ``'lr'`` and ``'weight_decay'``.
            These are copied into the returned best config.
        epochs: Number of epochs to run.

    Returns:
        A ``(history, best_accuracy, best_config)`` tuple:
        - ``history`` maps ``'epoch'`` (1-based) and ``'accuracy'`` to per-epoch lists.
        - ``best_accuracy`` is the highest accuracy seen.
        - ``best_config`` records the epoch that first reached it. It stays
          empty if accuracy never exceeds 0.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    best_accuracy = 0.0
    best_config = {}
    history = {'epoch': [], 'accuracy': []}

    # MKL-DNN (oneDNN) speeds up CPU convolutions; note this sets a process-wide flag.
    torch.backends.mkldnn.enabled = True

    for epoch in range(1, epochs + 1):
        _train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_accuracy = _evaluate_accuracy(model, test_loader, device)

        history['epoch'].append(epoch)
        history['accuracy'].append(val_accuracy)
        print(f'Epoch {epoch}, Accuracy: {val_accuracy:.4f}')

        scheduler.step(val_accuracy)  # Dynamic learning rate adjustment

        # Strict '>' keeps the earliest epoch that reached the best accuracy.
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_config = {
                'batch_size': config['batch_size'],
                'lr': config['lr'],
                'epochs': epoch,  # Train up to this epoch to reproduce the best result
                'weight_decay': config['weight_decay'],
            }

    return history, best_accuracy, best_config


# Run training for the configured number of epochs.
history, best_accuracy, best_config = train_cifar100(
    model,
    train_loader,
    test_loader,
    optimizer,
    criterion,
    scheduler,
    device,
    config,
    epochs=config['epochs'],
)

# Plot per-epoch test accuracy, save a 300-dpi copy to Drive, then display it.
df = pd.DataFrame(history)
fig, ax = plt.subplots(figsize=(10, 5))
sns.lineplot(data=df, x='epoch', y='accuracy', ax=ax)
ax.set_xlabel('Epoch')
ax.set_ylabel('Test Accuracy')
ax.set_title('CIFAR-100 Training Results')
ax.grid()
fig.savefig("/content/drive/My Drive/cifar100_training_plot.png", dpi=300)
plt.show()

# Report the best accuracy and the hyperparameters recorded for that epoch
# (no hyperparameter lines are printed if best_config is empty).
print(f'Best Accuracy: {best_accuracy:.4f}')
print('Best Hyperparameters:')
for name, setting in best_config.items():
    print(f'  {name}: {setting}')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


KeyboardInterrupt: 

In [None]:
# IPython shell escape (notebook-only, not plain Python): installs Ray.
# Nothing visible in this file uses Ray; presumably a later cell does — confirm.
!pip install ray