In [13]:
from google.colab import drive

# Mount Google Drive; later cells save checkpoints and loss logs under /content/drive/MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [14]:
import torch
from torch import nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from torch.utils.data import DataLoader, random_split


In [15]:
# Prefer CUDA, then PyTorch's MPS backend, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [16]:
import numpy as np
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import datasets, transforms
import torch

# Load the CIFAR-10 training data (no transform) just to compute normalization statistics.
# Named distinctly from `train_data`, which later holds the transformed training subset.
cifar_train_raw = datasets.CIFAR10(root='./data', train=True, download=True)

# Calculate per-channel mean and std over all training pixels
# as per Github discussions (paulkorir, 2018): https://github.com/facebookarchive/fb.resnet.torch/issues/180#issuecomment-433419706
# Read torchvision's in-memory image array directly (the per-sample PIL images are built
# from it) instead of converting 50k PIL images one by one and concatenating them.
train_images_hwc = np.asarray(cifar_train_raw.data)
assert train_images_hwc.ndim == 4 and train_images_hwc.shape[-1] == 3, "expected (N, H, W, 3) images"
train_pixels = train_images_hwc.reshape(-1, 3)  # one row per pixel, one column per RGB channel
train_mean = train_pixels.mean(axis=0) / 255.0
train_std = train_pixels.std(axis=0) / 255.0
print("Mean:", train_mean)
print("Std:", train_std)

# Define transformations
# Experiment with/tune augmentations here
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, hue=0.1),
    transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist())
])

# Validation/test: same resize + normalization, no random augmentation.
val_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=train_mean.tolist(), std=train_std.tolist())
])

class TransformDataset(Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and transform the image on access.

    Lets different views of the same underlying data (e.g. the train/val subsets
    produced by ``random_split``) use different transforms. Labels pass through
    unchanged; when ``transform`` is falsy the image is returned as-is.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset      # indexable, sized source of (image, label) pairs
        self.transform = transform  # callable applied to the image, or None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        return (self.transform(sample) if self.transform else sample), target

# Load raw dataset without transforms first
train_dataset_raw = datasets.CIFAR10(root='./data', train=True, transform=None, download=True)

# Split raw dataset.
# Seed the split so every kernel restart yields the same train/val partition: the best
# checkpoint is persisted to Drive, and with an unseeded split a reloaded model could be
# "validated" on images it was trained on in an earlier session.
val_split = 0.2
split_seed = 42
train_size = int((1 - val_split) * len(train_dataset_raw))
val_size = len(train_dataset_raw) - train_size
train_subset_raw, val_subset_raw = random_split(
    train_dataset_raw,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create separate datasets with appropriate transforms
train_data = TransformDataset(train_subset_raw, train_transform)
val_data = TransformDataset(val_subset_raw, val_test_transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=val_test_transform, download=True)

# Create DataLoaders
batch_size = 256
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print dataset sizes
print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")
print(f"Test samples: {len(test_dataset)}")

Files already downloaded and verified
Mean: [0.49139968 0.48215841 0.44653091]
Std: [0.24703223 0.24348513 0.26158784]
Files already downloaded and verified
Files already downloaded and verified
Training samples: 40000
Validation samples: 10000
Test samples: 10000


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class VGG_Network(nn.Module):
    """VGG-style image classifier (no batch norm) in the 'vgg11' or 'vgg16' layout.

    Five conv stages (each: N x [3x3 conv, padding 1 -> ReLU], then a 2x2/stride-2
    max-pool) with widths 64-128-256-512-512, then adaptive average pooling to 7x7
    and a 512*7*7 -> 4096 -> 4096 -> num_classes MLP head with dropout 0.5.

    Args:
        input_size: (channels, height, width). Only ``input_size[0]`` is read; it sets
            the input channels of the first conv.
        num_classes: Number of output classes.
        config: ``'vgg11'`` (1-1-2-2-2 convs per stage) or ``'vgg16'`` (2-2-3-3-3).

    Raises:
        ValueError: If ``config`` is neither ``'vgg11'`` nor ``'vgg16'``.
    """

    def __init__(self, input_size, num_classes, config='vgg16'):
        super(VGG_Network, self).__init__()

        if config == 'vgg11':
            print("Constructing VGG11")
            self.conv2d_block1 = self.conv2d_block(input_size[0], 64, 1)
            self.conv2d_block2 = self.conv2d_block(64, 128, 1)
            self.conv2d_block3 = self.conv2d_block(128, 256, 2)
            self.conv2d_block4 = self.conv2d_block(256, 512, 2)
            self.conv2d_block5 = self.conv2d_block(512, 512, 2)

        elif config == 'vgg16':
            self.conv2d_block1 = self.conv2d_block(input_size[0], 64, 2)
            self.conv2d_block2 = self.conv2d_block(64, 128, 2)
            self.conv2d_block3 = self.conv2d_block(128, 256, 3)
            self.conv2d_block4 = self.conv2d_block(256, 512, 3)
            self.conv2d_block5 = self.conv2d_block(512, 512, 3)

        else:
            # Fail at construction instead of with an AttributeError inside forward().
            raise ValueError(f"Unsupported config {config!r}; expected 'vgg11' or 'vgg16'")

        # Five stride-2 pools reduce 224x224 inputs to 7x7, where this pool is an exact
        # identity; for other input sizes it resizes the feature map so linear1 still fits.
        # It has no parameters, so existing state_dict checkpoints still load.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.linear1 = nn.Linear(512 * 7 * 7, 4096)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout(0.5)
        self.linear2 = nn.Linear(4096, 4096)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout(0.5)
        self.linear3 = nn.Linear(4096, num_classes)

    def conv2d_block(self, in_channels, out_channels, num_layers):
        """Build one VGG stage: ``num_layers`` x (3x3 conv, padding 1 -> ReLU), then 2x2 max-pool.

        The first conv maps ``in_channels`` -> ``out_channels`` and later convs keep
        ``out_channels``. At least one conv is always created. The convs preserve the
        spatial size; the stride-2 pool halves it (rounding down).
        """
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        for _ in range(num_layers - 1):
            layers.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batched image tensor (batch, C, H, W) to (batch, num_classes) log-probabilities.

        The output is already log_softmax-ed. nn.CrossEntropyLoss applies log_softmax
        again, which leaves log-probabilities unchanged, so training with it is still
        correct; nn.NLLLoss is the conventional pairing for this output.
        """
        x = self.conv2d_block1(x)
        x = self.conv2d_block2(x)
        x = self.conv2d_block3(x)
        x = self.conv2d_block4(x)
        x = self.conv2d_block5(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.linear2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        x = self.linear3(x)
        return nn.functional.log_softmax(x, dim=1)  # output

# Function to train the model
def _epoch_loss(model, loader, criterion, device, optimizer=None):
    """Run one pass of ``model`` over ``loader`` and return the mean loss per sample.

    With ``optimizer`` given, the model is put in train mode and every batch is
    backpropagated and stepped. Without it, the model runs in eval mode with
    gradients disabled.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss = 0.0
    total_samples = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # Assumes the criterion averages over the batch (reduction='mean', the default).
            # Re-weighting by batch size stops a short final batch from counting as much
            # as a full one.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("loader yielded no samples; cannot compute an average loss")
    return total_loss / total_samples


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, scheduler, patience=5,
                checkpoint_path='/content/drive/MyDrive/best_model.pth',
                losses_path='/content/drive/MyDrive/losses.txt'):
    """Train ``model`` with LR scheduling and early stopping on validation loss.

    Each epoch runs one training pass over ``train_loader`` and one evaluation pass
    over ``val_loader``, then calls ``scheduler.step(average_val_loss)`` (the
    ReduceLROnPlateau calling convention). When the validation loss improves, the
    model's ``state_dict`` is saved to ``checkpoint_path``. After ``patience``
    consecutive epochs without improvement, training stops early. Batches are moved
    to the device of the model's first parameter.

    Args:
        model: Module to train.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        criterion: Loss function ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Maximum number of epochs.
        scheduler: Object whose ``step`` accepts the epoch's validation loss.
        patience: Epochs without validation improvement before stopping.
        checkpoint_path: File the best ``state_dict`` is saved to.
        losses_path: Text file the per-epoch losses are written to at the end.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean loss per sample.

    Raises:
        ValueError: If either loader yields no samples.
    """
    # Follow the model's own device instead of relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    best_val_loss = float('inf')
    epochs_without_improvement = 0

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        print(f'Epoch [{epoch+1}/{num_epochs}]')

        average_train_loss = _epoch_loss(model, train_loader, criterion, model_device, optimizer)
        average_val_loss = _epoch_loss(model, val_loader, criterion, model_device)
        scheduler.step(average_val_loss)

        # Store losses for plotting
        train_losses.append(average_train_loss)
        val_losses.append(average_val_loss)

        print(f"Average Train Loss: {average_train_loss:.4f}")
        print(f"Average Validation Loss: {average_val_loss:.4f}")

        # Check if validation loss improved
        if average_val_loss < best_val_loss:
            best_val_loss = average_val_loss
            epochs_without_improvement = 0
            # Save best model
            torch.save(model.state_dict(), checkpoint_path)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f'Early stopping at epoch {epoch+1}')
                break

    # Save the average losses to a text file
    with open(losses_path, 'w') as f:
        f.write("Average Train Losses:\n")
        f.writelines(f"{loss}\n" for loss in train_losses)
        f.write("\nAverage Validation Losses:\n")
        f.writelines(f"{loss}\n" for loss in val_losses)

    return train_losses, val_losses

# Example usage
def main():
    """Train a VGG11 on the CIFAR-10 loaders built earlier and plot its loss curves.

    Uses notebook globals from earlier cells: ``device``, ``train_loader``,
    ``val_loader`` and the Colab ``drive`` handle. The plot is saved to
    ``train_val_loss.png`` in the working directory, then Drive is flushed and unmounted.
    """
    # --- hyperparameters ---
    input_size = (3, 224, 224)
    num_classes = 10
    learning_rate = 1e-2
    momentum = 0.9
    weight_decay = 5 * 1e-4
    # Upper bound only: train_model stops early once validation loss stalls for `patience` epochs.
    num_epochs = 1000

    model = VGG_Network(input_size, num_classes, config='vgg11').to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay
    )
    # Cut the learning rate by 10x once validation loss plateaus (patience of 2 epochs).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', patience=2, factor=0.1, verbose=True
    )

    train_losses, val_losses = train_model(
        model, train_loader, val_loader, criterion, optimizer, num_epochs, scheduler
    )

    # One point per completed epoch for each curve.
    fig, ax = plt.subplots()
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.set_title('Train and Validation Loss')
    fig.savefig('train_val_loss.png')
    plt.show()

    # Important in Colab so files written to Google Drive show up there promptly.
    drive.flush_and_unmount()


if __name__ == '__main__':
    main()


Constructing VGG11
Epoch [1/1000]


In [None]:
# Python graphviz bindings used for rendering the torchview graphs below.
# %pip (rather than relying on automagic) installs into this kernel's environment.
%pip install graphviz


In [None]:
# Install torchview once, directly from its GitHub repository. A separate PyPI install
# beforehand would just be replaced by this one. %pip targets this kernel's environment.
%pip install git+https://github.com/mert-kurttutan/torchview.git


In [None]:
from torchview import draw_graph
input_size = (3, 224, 224)
num_classes = 10
# Trace on the 'meta' device so only shapes are propagated: drawing the graph then needs
# no GPU memory (the model is built on CPU and never sent to `device`).
# NOTE(review): torchview appears to move the model to the trace device, so treat
# `model` as visualization-only after this cell.
model = VGG_Network(input_size, num_classes, config='vgg11')
model_graph = draw_graph(model, input_size=(1, 3, 224, 224), device='meta')
model_graph.visual_graph

In [None]:
from torchview import draw_graph
# Defined here too so this cell does not depend on the previous one having run.
input_size = (3, 224, 224)
num_classes = 10
# Trace on the 'meta' device so only shapes are propagated: drawing the graph then needs
# no GPU memory (the model is built on CPU and never sent to `device`).
# NOTE(review): torchview appears to move the model to the trace device, so treat
# `model` as visualization-only after this cell.
model = VGG_Network(input_size, num_classes, config='vgg16')
model_graph = draw_graph(model, input_size=(1, 3, 224, 224), device='meta')
model_graph.visual_graph