In [None]:
# !pip install -U "git+https://github.com/ab7289-tandon-nyu/csgy6953_DeepLearning_Midterm.git"

!git clone -b auto_augment "https://github.com/ab7289-tandon-nyu/csgy6953_DeepLearning_Midterm.git"
!cp -r /content/csgy6953_DeepLearning_Midterm/src/ .

In [None]:
# 
!pip install wandb
!wandb login "996181dd165ce17c309c3d027297e4ed8952f4ec"

In [None]:
import torch
import torch.nn as nn

import time
import random

In [None]:
# Mount Google Drive so the checkpoint path used later (under /content/drive) is reachable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Seed the Python, CPU torch and CUDA torch RNGs so runs are repeatable.
SEED = 1234

for seed_rng in (random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_rng(SEED)

# Ask cuDNN to pick deterministic kernels.
torch.backends.cudnn.deterministic = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Build the train/validation/test datasets with the `make_auto_transforms`
# pipeline from src.transforms (presumably AutoAugment-based, per the branch name
# — confirm), then wrap each split in a data loader.
from src.data import get_transformed_data, make_data_loaders
from src.transforms import make_auto_transforms

batch_size = 512
# Fraction of data held out for validation; how the split is made is decided
# inside src.data.get_transformed_data — TODO confirm it is carved from train.
valid_ratio = 0.1

train_data, valid_data, test_data = get_transformed_data(
    make_transforms=make_auto_transforms,
    valid_ratio=valid_ratio,
)

train_iter, valid_iter, test_iter = make_data_loaders(
    train_data, valid_data, test_data, batch_size
)

**Define our Model**

In [None]:
from src.model import ResNet, StemConfig
from src.utils import initialize_parameters, epoch_time

# One tuple per stage, passed straight to src.model.ResNet; the meaning of each
# field is defined there (presumably (num_blocks, channels, <ratio>) — TODO confirm).
architecture = [(1, width, 0.5) for width in (64, 128, 256, 512)]

# Stem: 64 channels, 3x3 kernel, stride 1, padding 1.
config = StemConfig(
    num_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
)
# 10 output classes.
model = ResNet(architecture, stem_config=config, output_size=10)

In [None]:
from pathlib import Path

# Directory on the mounted Google Drive where the best checkpoint is stored.
path = "/content/drive/MyDrive/Colab Notebooks/midterm/"
model_file = Path(path) / "resnet_alex_49m_auto_augment.pt"
# Plain-string form of the same location, used by torch.save / torch.load below.
file_path = str(model_file)

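If no saved checkpoint exists, a dummy batch must first be passed through the model to initialize its lazy modules; only then can the parameters be initialized, counted, or summarized (e.g. with torchsummary).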

In [None]:
# Either restore the weights of a previous run or build and initialize a fresh model.
# In both branches the model ends up on `device`, where training and evaluation run.
if model_file.exists() and model_file.is_file():
    # Load the previously saved weights. `map_location` lets a checkpoint written
    # on a GPU runtime be restored on a CPU-only runtime (and vice versa).
    model.load_state_dict(torch.load(model_file, map_location=device))
    # Keep the restored model on the same device as the fresh-model branch.
    model = model.to(device)
else:
    # The model uses lazy modules, whose parameter shapes are only known after a
    # first forward pass; run one random batch to materialize them.
    inputs = torch.empty((batch_size, 3, 32, 32))
    inputs.normal_()
    model = model.to(device)
    # This pass only infers shapes, so skip building an autograd graph.
    with torch.no_grad():
        y = model(inputs.to(device))
    print(y.size())

    # NOTE(review): the dummy pass runs in training mode, so any BatchNorm running
    # statistics are updated with noise unless initialize_parameters resets them — confirm.
    model.apply(initialize_parameters)

In [None]:
# Report the number of trainable parameters (thousands-separated).
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"num params: {trainable_params:,}")

In [None]:
from src.engine import train_one_epoch, evaluate

EPOCHS = 100
learning_rate = 1e-3
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Lowest validation loss seen so far; the training loop overwrites the checkpoint
# only when this improves. If weights were restored from `model_file`, start from
# that model's validation loss so a worse first epoch cannot clobber the saved best.
best_loss = float('inf')
if model_file.exists() and model_file.is_file():
    best_loss, _ = evaluate(model.to(device), valid_iter, criterion, device)

In [None]:
# Start a Weights & Biases run and record the key hyperparameters with it.
import wandb

run_config = dict(
    learning_rate=learning_rate,
    epochs=EPOCHS,
    batch_size=batch_size,
    architecture=architecture,
)

wandb.init(
    project='ResNet_5M',
    entity="dlf22_mini_project",
    name="resnet_alex_49m_auto_augment",
    config=run_config,
)

In [None]:
for epoch in range(1, EPOCHS + 1):
    start = time.time()

    print(f"Epoch {epoch}")
    # --- train ---
    train_loss, train_acc = train_one_epoch(model, train_iter, criterion, optimizer, device)
    train_mins, train_secs = epoch_time(start, time.time())
    print(f"\tTrain elapsed: {train_mins}:{train_secs}, loss: {train_loss:.4f}, acc: {train_acc * 100:.2f}%")

    # --- validate ---
    start = time.time()
    val_loss, val_acc = evaluate(model, valid_iter, criterion, device)
    val_mins, val_secs = epoch_time(start, time.time())
    print(f"\tValidation elapsed: {val_mins}:{val_secs}, loss: {val_loss:.4f}, acc: {val_acc * 100:.2f}%")

    # Log once per epoch: every wandb.log call advances the W&B step counter, so
    # separate train/val calls would put the two sets of metrics on different steps.
    wandb.log({
        "train_loss": train_loss,
        "train_acc": train_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
        "epoch": epoch,
    })

    # Keep only the checkpoint with the lowest validation loss.
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), file_path)
        wandb.run.summary["best_val_loss"] = best_loss

## Evaluate the Model  

In [None]:
# Evaluate the best checkpoint (lowest validation loss), not the last-epoch weights.
# `map_location` makes the load work regardless of the device the checkpoint was saved from.
model.load_state_dict(torch.load(file_path, map_location=device))
model = model.to(device)
test_loss, test_acc = evaluate(model, test_iter, criterion, device)
print(f"Test Loss: {test_loss:.4f}\nTest Accuracy: {test_acc * 100:.2f}%")

wandb.log({
    "test_loss": test_loss,
    "test_acc": test_acc,
})

# Close the W&B run explicitly: in a notebook the process does not exit after the
# last cell, so the run would otherwise stay open.
wandb.finish()