In [1]:
import os
import json
import sys

# Put the parent of the current working directory on the import path;
# presumably this is what lets the `src.*` imports below resolve — verify
# that the notebook is launched from its own folder.
root_dir = os.path.join(os.getcwd(), os.pardir)
sys.path.append(root_dir)

In [2]:
from src.data.get_dataloader import get_dataloaders
from src.model.get_model import get_model
from src.visualization.acc_loss import plot_acc_loss
from src.model.init_weights import init_weights
from src.experiment.utils import one_hot_encode
from src.model.model_files import save_model

In [3]:
import torch
import torch.nn as nn
from torchvision import models

from tqdm import tqdm
import numpy as np

In [4]:
# Load the example experiment configuration. The path is relative to the
# notebook's working directory.
config_path = os.path.join("..", "config", "example_config.json")
# Pin the encoding: without it open() falls back to the platform's locale
# default, which can mis-decode a UTF-8 JSON file on some systems.
with open(config_path, encoding="utf-8") as json_file:
    config = json.load(json_file)

In [5]:
# Override the loaded example settings for this experiment: VGG11 on CIFAR-10.
# NOTE(review): save_path is relative to the working directory, while the data
# path shown in the config dump is '../data' — confirm './models' is intended.
config['model'].update(type='vgg11', save_path='./models/vgg11.pth')
config['data'].update(dataset='cifar10')

In [6]:
def compute_loss_accuracy(logits, labels, criterion):
    """Evaluate the loss and count correct top-1 predictions for one batch.

    Args:
        logits: Model outputs; the predicted class is the argmax over dim 1.
        labels: Targets handed to ``criterion`` and compared element-wise with
            the predictions (assumed to be class indices — TODO confirm).
        criterion: Callable invoked as ``criterion(logits, labels)``.

    Returns:
        A ``(loss, correct)`` tuple: ``loss`` is whatever ``criterion``
        returns, ``correct`` is the number of matching predictions as a
        plain Python number.
    """
    loss = criterion(logits, labels)
    predicted = logits.argmax(dim=1)
    # Reshape labels to the prediction layout before the element-wise compare.
    num_correct = (predicted == labels.view_as(predicted)).sum().item()
    return loss, num_correct

# Prepare to train

In [7]:
# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [8]:
# Build the train / validation / test loaders from the experiment config using
# the project helper imported from src.data.get_dataloader.
train_loader, val_loader, test_loader = get_dataloaders(config)

Creating dataloaders...
Files already downloaded and verified
Files already downloaded and verified
Dataloaders created


In [9]:
# Start from torchvision's VGG11 with its default weights and replace the last
# classifier layer with a fresh linear head for this dataset.
model = models.vgg11(weights='DEFAULT')
# Size the head from the layer it replaces and from the config instead of
# hard-coding 4096 -> 10, so it follows config['data']['num_classes'].
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(head_in_features, config['data']['num_classes'], bias=True)
model = model.to(device)

In [10]:
# Loss function: cross-entropy, moved to the selected device.
criterion = nn.CrossEntropyLoss().to(device)

# SGD over all model parameters, with hyper-parameters read from the config.
# The 'nestrov' spelling is the key exactly as it appears in the config.
sgd_cfg = config['optimizer']
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=sgd_cfg['lr'],
    momentum=sgd_cfg['momentum'],
    weight_decay=sgd_cfg['weight_decay'],
    nesterov=sgd_cfg['nestrov'],
)

# Train the modified model for CIFAR-10

In [11]:
# Display the effective configuration (including the overrides applied above)
# as this cell's output.
config

{'experiment_name': 'test',
 'data': {'path': '../data',
  'dataset': 'cifar10',
  'image_channels': 1,
  'num_classes': 10,
  'batch_size': 512,
  'num_workers': 2,
  'val_split': 0.2},
 'model': {'type': 'vgg11',
  'num_blocks': [2, 2, 2, 2],
  'save_path': './models/vgg11.pth',
  'init_method': 'normal',
  'init_mean': 0.0,
  'init_std': 0.0001},
 'optimizer': {'type': 'sgd',
  'lr': 0.0001,
  'momentum': 0.9,
  'weight_decay': 0.0001,
  'nestrov': False},
 'training': {'epochs': 15, 'criterion': 'cross_entropy'},
 'visualize': {'perform': False, 'save_path': './reports/figures/'}}

In [13]:
def run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy_pct)``.

    Args:
        model: Network to run; it is switched to train mode when an optimizer
            is given and to eval mode otherwise.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable forwarded to ``compute_loss_accuracy``.
        device: Device the batches are moved to before the forward pass.
        optimizer: If provided, parameters are updated after every batch.
            If ``None``, the pass is evaluation-only and runs without
            gradient tracking.

    Returns:
        ``mean_loss`` is the average of the per-batch losses and
        ``accuracy_pct`` is the percentage of correctly classified samples
        among the samples actually iterated.
    """
    training = optimizer is not None
    # Layers such as dropout behave differently in train and eval mode, so the
    # mode must be set explicitly for each pass.
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    total_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()

            logits = model(inputs)
            loss, correct = compute_loss_accuracy(logits, labels, criterion)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            total_correct += correct
            # Count the samples actually seen rather than len(loader.dataset),
            # which is wrong when the loader uses a sampler or drop_last.
            total_seen += labels.size(0)

    mean_loss = total_loss / max(len(loader), 1)
    accuracy_pct = total_correct * 100 / max(total_seen, 1)
    return mean_loss, accuracy_pct


# NOTE(review): the recorded run of this cell stopped with an MPS out-of-memory
# error during the first epoch (batch_size is 512 in the config dump); a smaller
# config['data']['batch_size'] is probably needed on that device — confirm.
epoch_describer = tqdm(range(config['training']['epochs']), desc="Train", ncols=100)

# Per-epoch history, kept at notebook scope for later inspection.
train_losses = []
train_accs = []
val_losses = []
val_accs = []

for epoch in epoch_describer:
    train_loss, train_acc = run_epoch(model, train_loader, criterion, device, optimizer)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    epoch_describer.set_description(f"Train (loss={train_loss:.3f}, acc={train_acc:.3f})")

    # Evaluation-only pass: no optimizer, eval mode, no gradient tracking.
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    epoch_describer.set_postfix(val_loss=f"{val_loss:.3f}", val_acc=f"{val_acc:.3f}")
    

Train:   0%|                                                                 | 0/15 [01:01<?, ?it/s]


RuntimeError: MPS backend out of memory (MPS allocated: 15.70 GB, other allocations: 3.37 GB, max allowed: 18.13 GB). Tried to allocate 784.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# Persist the trained model via the project helper from src.model.model_files;
# presumably it writes to config['model']['save_path'] — verify in that module.
save_model(model, config)