# My first CNN

In [1]:
import sys
# Put the parent directory on the module search path.
# NOTE(review): presumably this is what lets the `src` package imported in the
# next cell resolve when the notebook is run from its own folder — confirm.
sys.path.append("..")

In [2]:
# Self imports
from src.helpers import *

In [3]:
import torch,os
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from ignite.engine import (
    Engine,
    Events,
    create_supervised_trainer,
    create_supervised_evaluator,
)
from ignite.metrics import Accuracy, Loss
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import TensorboardLogger, global_step_from_engine

## Step 1 - Grab your device

In [4]:
# Choose the compute device as a plain string (e.g. "cuda" or "cpu").
# The `torch.accelerator` namespace is not present in older PyTorch builds, so
# it is looked up defensively; when it is missing we fall back to the classic
# CUDA availability check instead of failing with AttributeError.
_accelerator_api = getattr(torch, "accelerator", None)
if _accelerator_api is not None and _accelerator_api.is_available():
    device = _accelerator_api.current_accelerator().type
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


## Step 2 - Formulate your data

In [5]:
# Build the train and test data through the project helpers, but only when the
# "data" directory is missing or empty.
data_ready = os.path.isdir("data") and bool(os.listdir("data"))
if not data_ready:
    create_data_dirs()
    # Train split first, then test split; circles before triangles in each split
    for split, n_images in ((TypeOfData.TRAIN, 5000), (TypeOfData.TEST, 1000)):
        create_circle(split, n_images)
        create_triangle(split, n_images)

# Preprocessing applied to every image when it is loaded
preprocessing_steps = [
    # The model does not use colour yet; the images are still generated as RGB
    # in case colour matters later, so collapse them to one channel here.
    transforms.Grayscale(),
    # Convert [0,255] to [0,1]
    transforms.ToTensor(),
    # Shift and scale [0,1] to [-1,1].
    # Note: ideally mean/std would be computed from the dataset beforehand.
    transforms.Normalize(mean=0.5, std=0.5),
]
transform = transforms.Compose(preprocessing_steps)

# Folder-backed datasets; both splits share the same preprocessing
train_dataset = datasets.ImageFolder(root="data/train", transform=transform)
test_dataset = datasets.ImageFolder(root="data/test", transform=transform)

# Batches of 64; only the training data is shuffled
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

## Step 3 - Formulate the model, define optimizer, loss fn

In [6]:
# Seed torch's global RNG before any weights are created so that parameter
# initialisation starts from the same state on every run. Other consumers of
# the global torch RNG are affected too; GPU kernels may still introduce small
# run-to-run differences.
SEED = 42
torch.manual_seed(SEED)

# Create the model and move it to the selected device
model = ShapeCNN().to(device)
# Define the loss function and the optimizer that updates the model's parameters
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

## Step 4 - Setup ignite

In [7]:
# Supervised training engine wired to the model, optimizer, loss and device
trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)

# Metrics reported by the evaluators (the same metric objects are given to both)
val_metrics = dict(accuracy=Accuracy(), loss=Loss(loss_fn))


def _build_evaluator():
    """Return a supervised evaluator for `model` that reports `val_metrics`."""
    return create_supervised_evaluator(model, metrics=val_metrics, device=device)


# One evaluator is run on the training loader, the other on the test loader
train_evaluator = _build_evaluator()
val_evaluator = _build_evaluator()

In [8]:
log_interval = 100


# Print the trainer's latest output once every `log_interval` iterations
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    """Print the epoch, iteration and most recent output held in `engine.state`."""
    state = engine.state
    print(f"Epoch[{state.epoch}], Iter[{state.iteration}] Loss: {state.output:.2f}")

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    """Run `train_evaluator` over the training loader and print its metrics."""
    train_evaluator.run(train_loader)
    results = train_evaluator.state.metrics
    epoch = trainer.state.epoch
    accuracy = results["accuracy"]
    loss = results["loss"]
    print(f"Training Results - Epoch[{epoch}] Avg accuracy: {accuracy:.2f} Avg loss: {loss:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    """Run `val_evaluator` over the test loader and print its metrics."""
    val_evaluator.run(test_loader)
    results = val_evaluator.state.metrics
    epoch = trainer.state.epoch
    accuracy = results["accuracy"]
    loss = results["loss"]
    print(f"Validation Results - Epoch[{epoch}] Avg accuracy: {accuracy:.2f} Avg loss: {loss:.2f}")


# Score function: reads one of the metrics defined in val_metrics from an engine
def score_function(engine):
    """Return the value stored under "accuracy" in `engine.state.metrics`."""
    current_metrics = engine.state.metrics
    return current_metrics["accuracy"]

In [9]:
# Checkpoint handler that keeps the n_saved best models as ranked by score_function
step_from_trainer = global_step_from_engine(trainer)  # take the step from the trainer's state
model_checkpoint = ModelCheckpoint(
    dirname="checkpoint",
    filename_prefix="best",
    n_saved=2,
    score_function=score_function,
    score_name="accuracy",
    global_step_transform=step_from_trainer,
    # Allow a non-empty target directory. Files written by earlier runs are left
    # in place, so the folder can hold more than n_saved files overall.
    require_empty=False,
)

In [10]:
# Call the checkpoint handler with the model each time a val_evaluator run completes
objects_to_save = {"model": model}
val_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, objects_to_save)

<ignite.engine.events.RemovableEventHandle at 0x7dc3c012ef90>

In [11]:
# TensorBoard logger writing into the "tb-logger" directory
tb_logger = TensorboardLogger(log_dir="tb-logger")


def _as_batch_loss(loss):
    """Wrap the trainer's output under the "batch_loss" key."""
    return {"batch_loss": loss}


# Record the trainer's loss every `log_interval` iterations under the "training" tag
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=log_interval),
    tag="training",
    output_transform=_as_batch_loss,
)

# Record all metrics of both evaluators whenever one of their epochs completes,
# using the trainer's state for the global step.
evaluators_by_tag = {"training": train_evaluator, "validation": val_evaluator}
for tag, evaluator in evaluators_by_tag.items():
    tb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag=tag,
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    )

In [12]:
# Train for five epochs over the training loader; the returned State is shown as the cell output
trainer.run(data=train_loader, max_epochs=5)

Epoch[1], Iter[100] Loss: 0.08
Training Results - Epoch[1] Avg accuracy: 0.98 Avg loss: 0.04
Validation Results - Epoch[1] Avg accuracy: 0.98 Avg loss: 0.05
Epoch[2], Iter[200] Loss: 0.01
Epoch[2], Iter[300] Loss: 0.00
Training Results - Epoch[2] Avg accuracy: 0.74 Avg loss: 0.94
Validation Results - Epoch[2] Avg accuracy: 0.74 Avg loss: 0.93
Epoch[3], Iter[400] Loss: 0.00
Training Results - Epoch[3] Avg accuracy: 1.00 Avg loss: 0.00
Validation Results - Epoch[3] Avg accuracy: 1.00 Avg loss: 0.00
Epoch[4], Iter[500] Loss: 0.00
Epoch[4], Iter[600] Loss: 0.00
Training Results - Epoch[4] Avg accuracy: 1.00 Avg loss: 0.00
Validation Results - Epoch[4] Avg accuracy: 1.00 Avg loss: 0.00
Epoch[5], Iter[700] Loss: 0.00
Training Results - Epoch[5] Avg accuracy: 1.00 Avg loss: 0.00
Validation Results - Epoch[5] Avg accuracy: 1.00 Avg loss: 0.00


State:
	iteration: 785
	epoch: 5
	epoch_length: 157
	max_epochs: 5
	output: 0.00013986413250677288
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [13]:
# Close the TensorBoard logger now that training has finished
tb_logger.close()

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Launch TensorBoard on the current directory, which contains the "tb-logger" log folder
%tensorboard --logdir=.

Reusing TensorBoard on port 6006 (pid 28502), started 1:53:37 ago. (Use '!kill 28502' to kill it.)

In [14]:
# Finally, list the saved checkpoint files. The checkpoint handler was created
# with require_empty=False, so files left over from earlier runs may be listed too.
!ls checkpoint


'best_model_3_accuracy=0.9985.pt'  'best_model_4_accuracy=1.0000.pt'
'best_model_4_accuracy=0.9970.pt'  'best_model_5_accuracy=1.0000.pt'
'best_model_4_accuracy=0.9990.pt'
