In [None]:
!pip install -U "git+https://github.com/ab7289-tandon-nyu/csgy6953_DeepLearning_Midterm.git"

In [None]:
# connect to our wandb project
!pip install wandb
!wandb login "API_KEY"

In [None]:
import torch
import torch.nn as nn
import random

import time

In [None]:
# Mount Google Drive into the Colab VM's filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Seed every RNG the notebook touches so runs are reproducible.
SEED = 1234

random.seed(SEED)
torch.manual_seed(SEED)
# Seed all visible GPUs; torch.cuda.manual_seed only seeds the current device.
torch.cuda.manual_seed_all(SEED)
# Force deterministic cuDNN kernels and explicitly disable the cuDNN autotuner,
# which may otherwise select different algorithms from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
from src.data import get_transformed_data, make_data_loaders
from src.transforms import make_transforms

BATCH_SIZE = 512
valid_ratio = 0.1

# Build the three dataset splits. Presumably `valid_ratio` of the training data
# is held out for validation — confirm in src.data.get_transformed_data.
train_data, valid_data, test_data = get_transformed_data(
    make_transforms=make_transforms,
    valid_ratio=valid_ratio,
)

# Wrap each split in a batched iterator sharing one batch size.
train_iterator, valid_iterator, test_iterator = make_data_loaders(
    train_data, valid_data, test_data, batch_size=BATCH_SIZE
)

## Define our Model

In [None]:
from src.model import ResNet, StemConfig
from src.utils import initialize_parameters, epoch_time

# Per-stage configuration passed to ResNet: one (1, 128) stage, three (2, 128)
# stages, then two (2, 196) stages. NOTE(review): tuple meaning is presumably
# (num_blocks, channels) — confirm against src.model.ResNet.
model_architecture = ((1, 128),) + ((2, 128),) * 3 + ((2, 196),) * 2

stem_config = StemConfig(
    num_channels=128,
    kernel_size=5,
    stride=1,
    padding=2,
)
model = ResNet(
    model_architecture,
    stem_config=stem_config,
    output_size=10,
)

In [None]:
from pathlib import Path

# TODO: replace this placeholder with a persistent directory (e.g. a folder on
# the mounted Google Drive) before training.
path = "/TODO/"
# Join with pathlib instead of string concatenation so `path` works with or
# without a trailing slash.
model_file = Path(path) / "TODO.pt"
# Same location as a plain string, for code that uses `file_path`.
file_path = str(model_file)

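We need to run a dummy batch of data through the model to initialize its lazy modules before we can initialize or count its parameters, create the optimizer, or inspect it with a tool such as torchsummary.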

In [None]:
if model_file.is_file():
    print("loading model")
    # Load our previously trained weights. map_location lets a checkpoint that
    # was saved on a GPU be restored on a CPU-only runtime (and vice versa).
    model.load_state_dict(torch.load(model_file, map_location=device))
    model = model.to(device)
else:
    # Initialize a new model.
    print("init new model parameters")
    model = model.to(device)

    # Push one random dummy batch through the model so its lazy modules
    # materialize their parameters before they are re-initialized below.
    inputs = torch.empty((BATCH_SIZE, 3, 32, 32))
    inputs.normal_()
    was_training = model.training
    # eval() keeps the random batch from updating any BatchNorm running
    # statistics the model may have; no_grad() avoids building an autograd
    # graph that would be thrown away immediately.
    model.eval()
    with torch.no_grad():
        y = model(inputs.to(device))
    model.train(was_training)
    print(y.size())

    model.apply(initialize_parameters)

In [None]:
from src.utils import count_parameters

# Presumably returns (total, trainable) counts — see src.utils.count_parameters.
num_params, grad_params = count_parameters(model)
trainable_msg = f"There are {grad_params:,} trainable parameters."
print(trainable_msg)

In [None]:
from src.engine import train_one_epoch, evaluate

best_loss = float('inf')
EPOCHS = 100
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

if model_file.is_file():
    # If we loaded a previously saved iteration, start best_loss from its loss,
    # otherwise the training loop could overwrite the save with a worse model.
    # Use the validation split: the loop compares val_loss against best_loss,
    # and the test set must not influence which checkpoint is kept.
    loss, acc = evaluate(model.to(device), valid_iterator, criterion, device)
    best_loss = loss
    print(f"Previous best val loss: {loss:.4f}, acc: {acc * 100:.2f}%")

In [None]:
# setup wandb logging
import wandb

# Hyperparameters recorded with the W&B run.
# NOTE(review): "avg_pool" is a hard-coded literal that is not passed to the
# model here — keep it in sync with src.model by hand.
wandb_config = {
    "learning_rate": learning_rate,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "architecture": model_architecture,
    "avg_pool": 4,
}

wandb.init(
    project='ResNet_5M',
    name="resnet_alex_49m_dropout",
    entity="dlf22_mini_project",
    config=wandb_config,
)

In [None]:
for epoch in range(1, EPOCHS+1):
    start = time.time()

    print(f"Epoch {epoch}")
    train_loss, train_acc = train_one_epoch(model, train_iterator, criterion, optimizer, device)
    train_mins, train_secs = epoch_time(start, time.time())

    wandb.log({
        "train_loss": train_loss,
        "train_acc": train_acc,
        "epoch": epoch
    })

    print(f"\tTrain elapsed: {train_mins}:{train_secs}, loss: {train_loss:.4f}, acc: {train_acc * 100:.2f}%")

    start = time.time()
    val_loss, val_acc = evaluate(model, valid_iterator, criterion, device)
    val_mins, val_secs = epoch_time(start, time.time())

    wandb.log({
        "val_loss": val_loss,
        "val_acc": val_acc,
        "epoch": epoch,
    })

    print(f"\tValidation elapsed: {val_mins}:{val_secs}, loss: {val_loss:.4f}, acc: {val_acc * 100:.2f}%")

    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), file_path)

## Evaluate the Model  

In [None]:
# Restore the saved checkpoint (lowest validation loss seen) before the final test.
# map_location keeps this working if the checkpoint was written on another device.
model.load_state_dict(torch.load(file_path, map_location=device))
test_loss, test_acc = evaluate(model.to(device), test_iterator, criterion, device)
print(f"Test Loss: {test_loss:.4f}\nTest Accuracy: {test_acc * 100:.2f}%")

wandb.log({
    "test_loss": test_loss,
    "test_acc": test_acc,
})
# Close the W&B run so everything is flushed and uploaded; in a notebook the
# run otherwise stays open until the kernel exits.
wandb.finish()