In [7]:

import torch.nn as nn
import os
import torch.nn.functional as F
from collections import Counter
from torchvision import transforms, datasets
import torchvision
import torchmetrics
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import time
import logging

In [1]:
BATCH_SIZE = 256
NUM_EPOCHS = 5
LEARNING_RATE = 0.01
NUM_WORKERS = 10 

In [8]:
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

In [4]:
print(f"PyTorch version: {torch.__version__}")

# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
print(f"Is MPS available? {torch.backends.mps.is_available()}")

# Set the device      
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

PyTorch version: 2.2.2
Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps


In [31]:
train_transform = transforms.Compose([
           torchvision.transforms.ToTensor(),
           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        
test_transform = transforms.Compose([
              torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]) 

train_dataset = datasets.CIFAR10(root = '../data', 
                                 train = True, 
                                 transform = train_transform, 
                                 download = True)

test_dataset = datasets.CIFAR10(root = '../data', 
                                train = False, 
                                transform = test_transform, 
                                download = True)
    

train_dataloader = DataLoader(train_dataset, 
                              batch_size = BATCH_SIZE, 
                              shuffle = True,
                              num_workers = NUM_WORKERS, 
                              drop_last = True)


test_dataloader = DataLoader(test_dataset, 
                             batch_size = BATCH_SIZE, 
                             shuffle = False, 
                            num_workers = NUM_WORKERS, 
                             drop_last = True)

Files already downloaded and verified
Files already downloaded and verified


In [32]:
class_names = train_dataset.classes
class_names

['airplane',
 'automobile',
 'bird',
 'cat',
 'deer',
 'dog',
 'frog',
 'horse',
 'ship',
 'truck']

In [33]:
len(train_dataset), len(test_dataset)

(50000, 10000)

In [34]:
class PyTorchVGG16(nn.Module):

    def __init__(self, num_classes):
        super().__init__()
        
        
        self.block_1 = nn.Sequential(
                nn.Conv2d(in_channels=3,
                          out_channels=64,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          # (1(32-1)- 32 + 3)/2 = 1
                          padding=1), 
                nn.ReLU(),
                nn.Conv2d(in_channels=64,
                          out_channels=64,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_2 = nn.Sequential(
                nn.Conv2d(in_channels=64,
                          out_channels=128,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=128,
                          out_channels=128,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_3 = nn.Sequential(        
                nn.Conv2d(in_channels=128,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=256,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),        
                nn.Conv2d(in_channels=256,
                          out_channels=256,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
          
        self.block_4 = nn.Sequential(   
                nn.Conv2d(in_channels=256,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),        
                nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),        
                nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),            
                nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))
        )
        
        self.block_5 = nn.Sequential(
                nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),            
                nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),            
                nn.Conv2d(in_channels=512,
                          out_channels=512,
                          kernel_size=(3, 3),
                          stride=(1, 1),
                          padding=1),
                nn.ReLU(),    
                nn.MaxPool2d(kernel_size=(2, 2),
                             stride=(2, 2))             
        )
        
        self.features = nn.Sequential(
            self.block_1, self.block_2, 
            self.block_3, self.block_4, 
            self.block_5
        )
            
        self.classifier = nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )
             
        # self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                #n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                #m.weight.data.normal_(0, np.sqrt(2. / n))
                m.weight.detach().normal_(0, 0.05)
                if m.bias is not None:
                    m.bias.detach().zero_()
            elif isinstance(m, torch.nn.Linear):
                m.weight.detach().normal_(0, 0.05)
                m.bias.detach().detach().zero_()
        
    def forward(self, x):

        x = self.features(x)
        # x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        logits = self.classifier(x)

        return logits

In [35]:
from tqdm import tqdm

def train(model, train_dataloader, test_dataloader, epochs, criterion, optimizer, device, scheduler = None):
    """
    This function trains the model
    
    Parameters
    ----------
    model: nn.Module
    train_dataloader: DataLoader
    test_dataloader: DataLoader
    epochs: int
    criterion: nn.Module
    optimizer: nn.Module
    device: str
    """
    
    model.to(device)
    
    results = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    
    for epoch in tqdm(range(epochs)):
        model.train()
        running_loss, running_acc = 0.0, 0.0
        
        for batch, (inputs, labels) in enumerate(train_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            running_acc += (outputs.argmax(1) == labels).float().mean()
            
        scheduler.step()
        epoch_loss = running_loss / len(train_dataloader)
        epoch_acc = running_acc / len(train_dataloader)
        
        logger.info(f'Epoch: {epoch} Training Loss: {epoch_loss}, Training Accuracy: {epoch_acc}')
        
        model.eval()
        test_loss, test_acc = 0.0, 0.0
        
        for batch, (inputs, labels) in enumerate(test_dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            test_loss += loss.item()
            test_acc += (outputs.argmax(1) == labels).float().mean()
            
        epoch_loss = running_loss / len(test_dataloader)
        epoch_acc = running_acc / len(test_dataloader)
        
        logger.info(f'Epoch: {epoch} Validation Loss: {epoch_loss}, Validation Accuracy: {epoch_acc}')
    
        results['train_loss'].append(running_loss)
        results['train_acc'].append(running_acc)
        results['val_loss'].append(test_loss)
        results['val_acc'].append(test_acc)
        
    return model, results

In [36]:
pytorch_model = PyTorchVGG16(num_classes = 10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(pytorch_model.parameters(), lr = LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.1)

pytorch_model, results = train(pytorch_model, train_dataloader, test_dataloader, NUM_EPOCHS, criterion, optimizer, device, scheduler)

  0%|          | 0/5 [00:00<?, ?it/s]2024-04-01 11:55:32,737 - __main__ - INFO - Epoch: 0 Training Loss: 99908.02902829586, Training Accuracy: 0.1004006415605545
2024-04-01 11:56:37,161 - __main__ - INFO - Epoch: 0 Validation Loss: 499540.14514147927, Validation Accuracy: 0.5020031929016113
 20%|██        | 1/5 [02:35<10:23, 155.82s/it]2024-04-01 11:58:03,467 - __main__ - INFO - Epoch: 1 Training Loss: 2.303011526205601, Training Accuracy: 0.10096153616905212
2024-04-01 11:59:07,588 - __main__ - INFO - Epoch: 1 Validation Loss: 11.515057631028004, Validation Accuracy: 0.504807710647583
 40%|████      | 2/5 [05:06<07:37, 152.65s/it]2024-04-01 12:00:34,369 - __main__ - INFO - Epoch: 2 Training Loss: 2.302917050092648, Training Accuracy: 0.09841746836900711
2024-04-01 12:01:39,153 - __main__ - INFO - Epoch: 2 Validation Loss: 11.514585250463242, Validation Accuracy: 0.49208733439445496
 60%|██████    | 3/5 [07:37<05:04, 152.15s/it]2024-04-01 12:03:05,393 - __main__ - INFO - Epoch: 3 Train

In [37]:
results

{'train_loss': [19482065.660517693,
  449.08724761009216,
  449.0688247680664,
  449.05244517326355,
  449.0285096168518],
 'train_acc': [tensor(19.5781, device='mps:0'),
  tensor(19.6875, device='mps:0'),
  tensor(19.1914, device='mps:0'),
  tensor(19.3242, device='mps:0'),
  tensor(19.5781, device='mps:0')],
 'val_loss': [89.81364727020264,
  89.80262422561646,
  89.80246043205261,
  89.80243444442749,
  89.80243253707886],
 'val_acc': [tensor(3.8984, device='mps:0'),
  tensor(3.8984, device='mps:0'),
  tensor(3.8984, device='mps:0'),
  tensor(3.8984, device='mps:0'),
  tensor(3.8984, device='mps:0')]}