In [1]:
import os
import copy
import ast
import time
import json
import random
import glob
import numpy as np
from functools import partial
from collections import Counter
from tqdm import tqdm_notebook as tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
#from torchsummary import summary

In [2]:
# --- Training configuration -------------------------------------------------
lr = 1e-3           # learning rate handed to the optimizer
n_epoch = 100       # upper bound of the epoch loop
batch_size = 256    # samples per batch for both data loaders
n_cls = 34          # number of target classes (width of the logit vector)

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cpu' if not torch.cuda.is_available() else 'cuda'

In [3]:
from ICJAIIterableDataset import ICJAIIterableDataset
from model import MJResNet50

In [4]:
# Datasets: the validation split only uses a small fraction (data_ratio) of its file.
train_set = ICJAIIterableDataset(
    'processed_data/discard_tile_train.nosync.txt',
    history_len=4,
)
val_set = ICJAIIterableDataset(
    'processed_data/discard_tile_val.nosync.txt',
    history_len=4,
    data_ratio=0.0025,
)

# Loaders: both splits are batched with the same batch size.
train_loader = DataLoader(train_set, batch_size=batch_size)
val_loader = DataLoader(val_set, batch_size=batch_size)

In [5]:
def compute_acc(pred, target):
    '''
    Fraction of samples whose highest-scoring class equals the target class.

    Args:
    - pred (torch.tensor, float32): unnormalized logits (before softmax) shape [bs, n_cls]
    - target (torch.tensor, int64): shape [bs]

    Returns:
    - acc (float): exact classification accuracy in [0, 1]; 0.0 for an empty batch
    '''
    n_samples = target.shape[0]
    if n_samples == 0:
        # Avoid 0/0 (NaN) poisoning running averages kept by the caller.
        return 0.0

    # Softmax is strictly increasing, so the argmax of the raw logits is the
    # predicted class; normalising first only costs time and precision.
    pred_cls = torch.argmax(pred, dim=-1)

    # Divide in Python so the result is a true float quotient regardless of
    # how the installed PyTorch version divides integer tensors.
    n_correct = (pred_cls == target).sum().item()
    return n_correct / n_samples

def validate(model, val_loader, epoch):
    '''
    Evaluate `model` on every batch of `val_loader` without tracking gradients.

    Args:
    - model (nn.Module): network to evaluate; it is switched to eval mode and
      left in eval mode on return
    - val_loader (iterable): yields (X, Y) batches of inputs and target labels
    - epoch (int): only used to label the progress bar

    Returns:
    - val_acc (float): accuracy over all validation samples

    Raises:
    - RuntimeError: if `val_loader` yields no samples
    '''
    model.eval()
    n_correct, n_samples = 0.0, 0
    pbar = tqdm(val_loader, desc=f"Epoch {epoch} Validation")
    with torch.no_grad():
        for X, Y in pbar:
            # Forward
            X, Y = X.to(device), Y.to(device)
            preds = model(X)

            bs = Y.shape[0]
            if bs == 0:
                continue

            # Weight each batch by its size so that a short final batch does
            # not count as much as a full one.
            n_correct += compute_acc(preds, Y) * bs
            n_samples += bs

    if n_samples == 0:
        raise RuntimeError('val_loader yielded no samples; cannot compute validation accuracy')

    val_acc = n_correct / n_samples
    pbar.set_postfix_str(f'Val Acc: {val_acc:.4f}')
    return val_acc

def save_checkpoints(epoch, model, optimizer, train_loss, val_acc):
    '''
    Write a training checkpoint to `discard_ckpts/ep{epoch}-val_acc{val_acc}.tar`.

    Args:
    - epoch (int): epoch index stored in the checkpoint and in the file name
    - model (nn.Module): its state_dict is saved
    - optimizer (torch.optim.Optimizer): its state_dict is saved
    - train_loss (float): mean training loss of the epoch, stored under 'loss'
    - val_acc (float): validation accuracy, stored under 'val_acc' and used
      (rounded to 4 decimals) in the file name
    '''
    save_dir = 'discard_ckpts/'
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f'ep{epoch}-val_acc{val_acc:.4f}.tar')
    torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            # Store the value passed by the caller, not whatever global
            # `loss` happens to exist in the notebook namespace.
            'loss': train_loss,
            'val_acc': val_acc,
            }, save_path)
    print('Checkpoint saved')

In [6]:
# Model
# The output width comes from the shared `n_cls` setting instead of a second
# hard-coded 34, so the two cannot drift apart. `history_len` must match the
# value the datasets were constructed with.
model = MJResNet50(history_len=4, n_cls=n_cls)
model.to(device)

# Loss function
# Optional class-weighted variant, kept for reference:
# cls_weights = [1/train_set.cls_ratios[i] for i in range(n_cls)]
# cls_weights = torch.tensor([w/sum(cls_weights) for w in cls_weights])
# criterion = nn.CrossEntropyLoss(cls_weights)
criterion = nn.CrossEntropyLoss()
criterion.to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Scheduler (disabled)
#scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=n_epoch, div_factor=20)

In [7]:
# Train from scratch: one pass over train_loader per epoch, followed by a
# validation pass; a checkpoint is written whenever validation accuracy
# matches or beats the best value seen so far.
best_val_acc = 0
ep_start = 0

for epoch in range(ep_start, n_epoch):
    model.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epoch}")
    train_loss, train_acc = 0, 0
    n_batches = 0
    for bi, (X, Y) in enumerate(pbar):
        optimizer.zero_grad()

        # Forward
        X, Y = X.to(device), Y.to(device)
        preds = model(X)

        # Calculate loss & update
        loss = criterion(preds, Y)
        loss.backward()
        optimizer.step()
        #scheduler.step()

        # Read the scalar loss once per step and reuse it below.
        batch_loss = loss.item()
        train_loss += batch_loss
        # Accuracy is a metric only: detach so no autograd graph is carried along.
        train_acc += compute_acc(preds.detach(), Y)
        n_batches = bi + 1

        pbar.set_postfix_str(f'Train loss: {batch_loss:.4f} | Train Acc: {(train_acc/n_batches):.4f}')

    # End of epoch
    if n_batches == 0:
        # Without this guard the averages below would divide by a missing or
        # stale batch index.
        raise RuntimeError('train_loader yielded no batches; nothing to train on')
    train_loss /= n_batches
    train_acc /= n_batches

    val_acc = validate(model, val_loader, epoch)

    if val_acc >= best_val_acc:
        best_val_acc = val_acc
        save_checkpoints(epoch, model, optimizer, train_loss, val_acc)

    pbar.set_postfix_str(f'Train loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  import sys


HBox(children=(HTML(value='Epoch 0/100'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px…




KeyboardInterrupt: 