In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchmetrics
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

import os
import time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Fix the global seed through Lightning before anything random happens, so a
# Restart & Run All starts from the same RNG state every time.
pl.seed_everything(0)

Seed set to 0


0

In [16]:
# load datasets
# Root directory passed to torchvision's STL10 dataset: it is both the download
# target and the location the train/test splits are read from.
DATA_PATH = './data/stl-10'

class DataModule(pl.LightningDataModule):
    """STL-10 data module producing 227x227 tensors.

    The 5000-image STL-10 ``train`` split is divided into 4000 training and
    1000 validation images; the official ``test`` split is exposed through
    ``test_dataloader``.

    Args:
        data_path: Directory the dataset is downloaded to and read from.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes per dataloader (0 loads in the main
            process, which is the DataLoader default).
        split_seed: Seed of the private generator that draws the train/valid
            split, so the split does not depend on how much of the global RNG
            stream was consumed before ``setup`` ran and is identical every
            time ``setup`` is called.
    """

    def __init__(self, data_path=DATA_PATH, batch_size=32, num_workers=0, split_seed=0):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed

        # Built here rather than in prepare_data(): Lightning calls
        # prepare_data() on a single process only, so state assigned there is
        # missing on every other process.
        self.data_transform = transforms.Compose([
            transforms.Resize((227, 227)),
            transforms.ToTensor()])

    def prepare_data(self):
        """Download both STL-10 splits to ``data_path``; assigns no state."""
        datasets.STL10(self.data_path, split='train', download=True)
        datasets.STL10(self.data_path, split='test', download=True)

    def setup(self, stage=None):
        """Build the train/valid/test datasets from the downloaded files.

        ``stage`` is accepted for Lightning compatibility; all three datasets
        are built regardless of its value.
        """
        # Read from the configured path (not the module-level default) and do
        # not download again: that is prepare_data()'s job.
        train = datasets.STL10(self.data_path, split='train', download=False,
                               transform=self.data_transform)
        self.test = datasets.STL10(self.data_path, split='test', download=False,
                                   transform=self.data_transform)

        split_rng = torch.Generator().manual_seed(self.split_seed)
        self.train, self.valid = random_split(
            train, lengths=[4000, 1000], generator=split_rng)

    def _loader(self, dataset, shuffle=False):
        """Create a DataLoader with the module-wide batch size and workers."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=self.num_workers)

    def train_dataloader(self):
        """Training batches, reshuffled every epoch."""
        return self._loader(self.train, shuffle=True)

    def val_dataloader(self):
        """Validation batches in fixed order."""
        return self._loader(self.valid)

    def test_dataloader(self):
        """Batches of the official STL-10 test split in fixed order."""
        return self._loader(self.test)
    
# Reset torch's global RNG to the same seed used at the top of the notebook so
# that re-running this cell on its own starts from a known state, then create
# the data module with its default data path.
torch.manual_seed(0)
data_module = DataModule()

In [17]:
class AlexNet(nn.Module):
    """AlexNet-style CNN for inputs of shape (batch, 3, 227, 227).

    ``self.net`` is the convolutional feature extractor and
    ``self.classifier`` the three-layer fully connected head. Layer positions
    inside both ``nn.Sequential`` containers are kept stable because
    ``init_weight`` addresses layers by index and checkpoint keys depend on
    them.

    Args:
        num_classes: Number of output logits.
    """

    # Indices in ``self.net`` of the 2nd, 4th and 5th convolutions, whose
    # biases are initialised to 1 instead of 0.
    _BIAS_ONE_CONV_INDICES = (4, 10, 12)

    def __init__(self, num_classes=10):
        super().__init__()
        # input_size = (batch_size x 3 x 227 x 227)

        self.net = nn.Sequential(
            # 227 -> (227 - 11) / 4 + 1 = 55
            nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 55 -> 27

            nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),  # 27 -> 13

            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),

            nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(3, 2)  # 13 -> 6
        )

        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=(256*6*6), out_features=4096),
            nn.ReLU(),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=4096),
            nn.ReLU(),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

        self.init_weight()

    def init_weight(self):
        """Initialise the convolutional layers.

        Every conv weight is drawn from N(0, 0.01) and every conv bias is set
        to 0; the biases of the convolutions listed in
        ``_BIAS_ONE_CONV_INDICES`` are then overwritten with 1. The linear
        layers in ``self.classifier`` are not touched and keep PyTorch's
        default initialisation.
        """
        for layer in self.net:
            if isinstance(layer, nn.Conv2d):
                nn.init.normal_(layer.weight, mean=0, std=0.01)
                nn.init.constant_(layer.bias, 0)
        for index in self._BIAS_ONE_CONV_INDICES:
            nn.init.constant_(self.net[index].bias, 1)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Raises:
            RuntimeError: if the input resolution does not yield the
                256 x 6 x 6 feature map the classifier expects.
        """
        x = self.net(x)
        # Flatten everything except the batch dimension. Unlike
        # ``view(-1, 256 * 6 * 6)`` this can never silently move features into
        # the batch dimension when the spatial size is wrong; the mismatch
        # surfaces as a shape error in the first linear layer instead.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [18]:
class LightningModel(pl.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy loss.

    Logs ``train_loss``/``val_loss`` and epoch-level ``train_acc``/``val_acc``/
    ``test_acc``.

    Args:
        model: Network mapping a batch of features to class logits.
        lr: Learning rate for the Adam optimizer.
        num_classes: Number of classes, used to configure the accuracy metrics.
    """

    def __init__(self, model, lr=0.001, num_classes=10):
        super().__init__()

        self.model = model
        self.lr = lr
        # The model itself is excluded so only lr/num_classes are stored.
        self.save_hyperparameters(ignore=['model'])

        # 'multiclass' is the spelling documented by torchmetrics; the
        # upper-case form only works where the task string is parsed
        # case-insensitively.
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)
        self.test_acc = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Run one forward pass on ``batch``.

        Returns:
            Tuple ``(loss, labels, predicted_labels)`` where ``loss`` is the
            cross-entropy of the logits and ``predicted_labels`` is the argmax
            over the class dimension.
        """
        features, labels = batch
        logits = self.model(features)
        loss = F.cross_entropy(logits, labels)
        predicted_labels = torch.argmax(logits, dim=1)
        return loss, labels, predicted_labels

    def training_step(self, batch, batch_idx):
        """Compute the training loss and update the training accuracy.

        Accuracy is measured with a second, gradient-free forward pass in eval
        mode, so the reported number is not affected by train-only behaviour
        of the wrapped model such as dropout. This costs one extra forward
        pass per step.
        """
        loss, labels, _ = self._shared_step(batch)
        self.log('train_loss', loss)

        features, _ = batch
        self.model.eval()
        try:
            with torch.no_grad():
                # Only predictions are needed here, so the loss is not
                # recomputed.
                predicted_labels = torch.argmax(self.model(features), dim=1)
        finally:
            # Always return to train mode, even if the extra pass fails.
            self.model.train()

        self.train_acc.update(predicted_labels, labels)
        self.log('train_acc', self.train_acc, on_epoch=True, on_step=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accumulate validation accuracy."""
        loss, labels, predicted_labels = self._shared_step(batch)
        self.log('val_loss', loss)
        self.val_acc.update(predicted_labels, labels)
        self.log('val_acc', self.val_acc, on_epoch=True, on_step=False, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Accumulate test accuracy; the loss is not logged for this stage."""
        _, labels, predicted_labels = self._shared_step(batch)
        self.test_acc.update(predicted_labels, labels)
        self.log('test_acc', self.test_acc, on_epoch=True, on_step=False)

    def configure_optimizers(self):
        """Adam on the wrapped model's parameters with a StepLR schedule.

        The schedule multiplies the learning rate by 0.1 every 30 epochs, so
        it has no effect on runs shorter than 30 epochs.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]

In [20]:
# Build the network and wrap it in the Lightning training module.
model = AlexNet(num_classes=10)
lightning_model = LightningModel(model, lr=0.001, num_classes=10)

# Checkpointing is driven by validation accuracy (higher is better).
callbacks = [ModelCheckpoint(monitor='val_acc', mode='max'),]

# Metrics are written as CSV under ./pl_models/alexnet_logs; the plotting cell
# reads them back through trainer.logger.log_dir.
os.makedirs('./pl_models', exist_ok=True)
logger = CSVLogger(save_dir='./pl_models', name='alexnet_logs')

trainer = pl.Trainer(
    max_epochs=10,
    logger=logger,
    callbacks=callbacks,
    # NOTE(review): the captured log reports an available MPS GPU that is left
    # unused. Confirm that every layer of the model is supported on that
    # backend before switching away from 'cpu'.
    accelerator='cpu',
    log_every_n_steps=5
)

# Wall-clock timing of the whole fit call, reported in minutes.
start_time = time.time()
trainer.fit(lightning_model, data_module)

runtime = (time.time() - start_time) / 60
print(f"Training time: {runtime:.2f} minutes in total")

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/minkyunjung/miniforge3/envs/condavenv/lib/python3.10/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name      | Type               | Params
-------------------------------------------------
0 | model     | AlexNet            | 58.3 M
1 | train_acc | MulticlassAccuracy | 0     
2 | val_acc   | MulticlassAccuracy | 0     
3 | test_acc  | MulticlassAccuracy | 0     
-------------------------------------------------
58.3 M    Trainable params
0         Non-trainable params
58.3 M    Total params
233.289   Total estimated model params size (MB)


Files already downloaded and verified
                                                                           

/Users/minkyunjung/miniforge3/envs/condavenv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 5:  84%|████████▍ | 105/125 [6:13:41<1:11:10,  0.00it/s, v_num=3, val_acc=0.093]Training time: 508.01 minutes in total


/Users/minkyunjung/miniforge3/envs/condavenv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [None]:
# Load the metrics the CSV logger wrote during training.
metrics = pd.read_csv(os.path.join(trainer.logger.log_dir, 'metrics.csv'))

# The log has several rows per epoch (step-level and epoch-level entries), so
# average them into one row per epoch. NaNs are skipped by mean(), and
# numeric_only guards against any non-numeric column in the log.
agg_col = 'epoch'
df_metrics = metrics.groupby(agg_col, as_index=False).mean(numeric_only=True)

# One figure per metric family, plotted against the real epoch number rather
# than the row position. Columns missing from the log (e.g. after a run that
# was interrupted before an epoch-level metric was written) are skipped
# instead of raising KeyError.
for wanted_columns, ylabel in (
        (['train_loss', 'val_loss'], 'Loss'),
        (['train_acc', 'val_acc'], 'Accuracy')):
    available = [col for col in wanted_columns if col in df_metrics.columns]
    if not available:
        continue
    df_metrics.plot(
        x=agg_col, y=available,
        grid=True, legend=True, xlabel='Epoch', ylabel=ylabel)

plt.show()