In [1]:
from dataset import get_dataloader
from model import get_mobilenet_v3, freeze_model, unfreeze_model
from utils import add_weight_decay

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List

torch.backends.cudnn.benchmark = True


# Hyper parameters
num_classes = 3
num_epoch = 10
batch_size = 1
lr = 1e-4
weight_decay = 1e-6
warmup_epoch = 3
cache = False
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize Model
model = get_mobilenet_v3(num_classes=num_classes)
model = model.to(device)
# Initialize dataloader
train_loader = get_dataloader('train', batch_size, num_classes=num_classes, cache=cache)
val_loader = get_dataloader('val', batch_size, num_classes=num_classes, cache=cache)
# Initialize Optimizer
optimizer = torch.optim.Adam(add_weight_decay(model), lr = lr)
# Loss function
criterion = nn.CrossEntropyLoss(reduction='mean')

# def compute_metrics(gt, predict) -> Dict[str, float]:
#     TP = ((gt == predict) & (gt == 1)).sum()
#     FP = ((gt != predict) & (gt == 1)).sum()
#     TN = ((gt == predict) & (gt == 1)).sum()
#     FN = ((gt == predict) & (gt == 1)).sum()

for epoch in range(num_epoch):
    print(f'epoch = {epoch}')
    start_time = time.time()
    optimizer.zero_grad(set_to_none=True)
    train_losses = 0.0
    train_accs = 0.0
    # Training
    model.training = True

    # Freeze feature layer for warm up
    if epoch < warmup_epoch:
        freeze_model(model)
    else:
        unfreeze_model(model)
    for i, sample in enumerate(tqdm(train_loader)):
        img, label = sample['img'].to(device), sample['label'].to(device)
        outputs = model(img)
        loss = criterion(outputs, label)
        loss.backward()
        train_losses += loss.item()
        
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Calculate accuracy
        label = torch.argmax(label, dim=1).detach().cpu()
        predicted_label = torch.argmax(outputs, dim=1).detach().cpu()
        acc = (label == predicted_label).sum() / len(label)

        train_accs += acc
        del loss
        break
    print('Training Loss : {:.4f}'.format(train_losses / len(train_loader)))
    print('Training Accuracy : {:.4f}'.format(train_accs / len(train_loader)))
    # Validation
    val_losses = 0.0
    val_accs = 0.0
    with torch.no_grad():
        model.eval()
        for sample in tqdm(val_loader):
            img, label = sample['img'].to(device), sample['label'].to(device)
            outputs = model(img)
            loss = criterion(outputs, label)
            val_losses += loss.item()

            # Calculate accuracy
            label = torch.argmax(label, dim=1).detach().cpu()
            predicted_label = torch.argmax(outputs, dim=1).detach().cpu()
            acc = (label == predicted_label).sum() / len(label)
    print('Validating Loss : {:.4f}'.format(val_losses / len(val_loader)))
    print('Validating Accuracy : {:.4f}'.format(val_accs / len(train_loader)))

  f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "


epoch = 0


  0%|          | 0/300 [00:13<?, ?it/s]


Training Loss : 0.0037
Training Accuracy : 0.0000


  0%|          | 0/150 [00:00<?, ?it/s]

torch.Size([1, 3, 1024, 1024])


  1%|          | 1/150 [00:04<11:32,  4.65s/it]

torch.Size([1, 3, 1024, 1024])


  1%|▏         | 2/150 [00:05<05:49,  2.36s/it]

torch.Size([1, 3, 1024, 1024])


  2%|▏         | 3/150 [00:06<04:14,  1.73s/it]

torch.Size([1, 3, 1024, 1024])


  3%|▎         | 4/150 [00:07<03:21,  1.38s/it]

torch.Size([1, 3, 1024, 1024])


  3%|▎         | 5/150 [00:07<02:47,  1.16s/it]

torch.Size([1, 3, 1024, 1024])


  4%|▍         | 6/150 [00:08<02:31,  1.05s/it]

torch.Size([1, 3, 1024, 1024])


  5%|▍         | 7/150 [00:09<02:15,  1.06it/s]

torch.Size([1, 3, 1024, 1024])


  5%|▌         | 8/150 [00:10<01:59,  1.19it/s]

torch.Size([1, 3, 1024, 1024])


  6%|▌         | 9/150 [00:10<01:48,  1.30it/s]

torch.Size([1, 3, 1024, 1024])


  7%|▋         | 10/150 [00:11<01:43,  1.35it/s]

torch.Size([1, 3, 1024, 1024])


  7%|▋         | 11/150 [00:12<01:39,  1.40it/s]

torch.Size([1, 3, 1024, 1024])


  8%|▊         | 12/150 [00:12<01:37,  1.42it/s]

torch.Size([1, 3, 1024, 1024])


  9%|▊         | 13/150 [00:13<01:33,  1.46it/s]

torch.Size([1, 3, 1024, 1024])


  9%|▉         | 14/150 [00:14<01:29,  1.52it/s]

torch.Size([1, 3, 1024, 1024])


 10%|█         | 15/150 [00:14<01:30,  1.48it/s]

torch.Size([1, 3, 1024, 1024])


 11%|█         | 16/150 [00:15<01:28,  1.52it/s]

torch.Size([1, 3, 1024, 1024])


 11%|█▏        | 17/150 [00:15<01:25,  1.55it/s]

torch.Size([1, 3, 1024, 1024])


 12%|█▏        | 18/150 [00:16<01:25,  1.54it/s]

torch.Size([1, 3, 1024, 1024])


 13%|█▎        | 19/150 [00:17<01:30,  1.45it/s]

torch.Size([1, 3, 1024, 1024])


 13%|█▎        | 20/150 [00:18<01:28,  1.47it/s]

torch.Size([1, 3, 1024, 1024])


 14%|█▍        | 21/150 [00:18<01:24,  1.52it/s]

torch.Size([1, 3, 1024, 1024])


 15%|█▍        | 22/150 [00:19<01:22,  1.55it/s]

torch.Size([1, 3, 1024, 1024])


 15%|█▌        | 23/150 [00:20<01:26,  1.46it/s]

torch.Size([1, 3, 1024, 1024])


 16%|█▌        | 24/150 [00:20<01:23,  1.51it/s]

torch.Size([1, 3, 1024, 1024])


 17%|█▋        | 25/150 [00:21<01:21,  1.54it/s]

torch.Size([1, 3, 1024, 1024])


 17%|█▋        | 26/150 [00:21<01:17,  1.59it/s]

torch.Size([1, 3, 1024, 1024])


 18%|█▊        | 27/150 [00:22<01:18,  1.56it/s]

torch.Size([1, 3, 1024, 1024])


 19%|█▊        | 28/150 [00:23<01:17,  1.58it/s]

torch.Size([1, 3, 1024, 1024])


 19%|█▉        | 29/150 [00:23<01:14,  1.62it/s]

torch.Size([1, 3, 1024, 1024])


 20%|██        | 30/150 [00:24<01:13,  1.64it/s]

torch.Size([1, 3, 1024, 1024])


 21%|██        | 31/150 [00:24<01:13,  1.62it/s]

torch.Size([1, 3, 1024, 1024])


 21%|██▏       | 32/150 [00:25<01:12,  1.62it/s]

torch.Size([1, 3, 1024, 1024])


 22%|██▏       | 33/150 [00:26<01:12,  1.61it/s]

torch.Size([1, 3, 1024, 1024])


 23%|██▎       | 34/150 [00:26<01:11,  1.62it/s]

torch.Size([1, 3, 1024, 1024])


 23%|██▎       | 35/150 [00:27<01:11,  1.61it/s]

torch.Size([1, 3, 1024, 1024])


 24%|██▍       | 36/150 [00:28<01:09,  1.65it/s]

torch.Size([1, 3, 1024, 1024])


 25%|██▍       | 37/150 [00:28<01:11,  1.58it/s]

torch.Size([1, 3, 1024, 1024])


 25%|██▌       | 38/150 [00:29<01:10,  1.59it/s]

torch.Size([1, 3, 1024, 1024])


 26%|██▌       | 39/150 [00:29<01:06,  1.67it/s]

torch.Size([1, 3, 1024, 1024])


 27%|██▋       | 40/150 [00:30<01:05,  1.68it/s]

torch.Size([1, 3, 1024, 1024])


 27%|██▋       | 41/150 [00:31<01:04,  1.69it/s]

torch.Size([1, 3, 1024, 1024])


 28%|██▊       | 42/150 [00:31<01:02,  1.73it/s]

torch.Size([1, 3, 1024, 1024])


 29%|██▊       | 43/150 [00:32<01:00,  1.76it/s]

torch.Size([1, 3, 1024, 1024])


 29%|██▉       | 44/150 [00:32<01:00,  1.75it/s]

torch.Size([1, 3, 1024, 1024])


 30%|███       | 45/150 [00:33<01:00,  1.74it/s]

torch.Size([1, 3, 1024, 1024])


 31%|███       | 46/150 [00:33<00:58,  1.77it/s]

torch.Size([1, 3, 1024, 1024])


 31%|███▏      | 47/150 [00:34<00:58,  1.76it/s]

torch.Size([1, 3, 1024, 1024])


 32%|███▏      | 48/150 [00:34<00:57,  1.77it/s]

torch.Size([1, 3, 1024, 1024])


 33%|███▎      | 49/150 [00:35<00:57,  1.77it/s]

torch.Size([1, 3, 1024, 1024])


 33%|███▎      | 50/150 [00:36<00:56,  1.78it/s]

torch.Size([1, 3, 1024, 1024])


 34%|███▍      | 51/150 [00:36<00:55,  1.78it/s]

torch.Size([1, 3, 1024, 1024])


 35%|███▍      | 52/150 [00:37<00:56,  1.73it/s]

torch.Size([1, 3, 1024, 1024])


 35%|███▌      | 53/150 [00:37<00:56,  1.73it/s]

torch.Size([1, 3, 1024, 1024])


 36%|███▌      | 54/150 [00:38<00:55,  1.75it/s]

torch.Size([1, 3, 1024, 1024])


 37%|███▋      | 55/150 [00:39<00:54,  1.74it/s]

torch.Size([1, 3, 1024, 1024])


 37%|███▋      | 56/150 [00:39<00:55,  1.70it/s]

torch.Size([1, 3, 1024, 1024])


 38%|███▊      | 57/150 [00:40<00:58,  1.60it/s]

torch.Size([1, 3, 1024, 1024])


 39%|███▊      | 58/150 [00:40<00:57,  1.60it/s]

torch.Size([1, 3, 1024, 1024])


 39%|███▉      | 59/150 [00:41<00:56,  1.61it/s]

torch.Size([1, 3, 1024, 1024])


 40%|████      | 60/150 [00:42<00:55,  1.63it/s]

torch.Size([1, 3, 1024, 1024])


 41%|████      | 61/150 [00:42<00:53,  1.65it/s]

torch.Size([1, 3, 1024, 1024])


 41%|████▏     | 62/150 [00:43<00:55,  1.57it/s]

torch.Size([1, 3, 1024, 1024])


 42%|████▏     | 63/150 [00:44<00:57,  1.52it/s]

torch.Size([1, 3, 1024, 1024])


 43%|████▎     | 64/150 [00:44<00:55,  1.54it/s]

torch.Size([1, 3, 1024, 1024])


 43%|████▎     | 65/150 [00:45<00:54,  1.57it/s]

torch.Size([1, 3, 1024, 1024])


 44%|████▍     | 66/150 [00:46<00:54,  1.55it/s]

torch.Size([1, 3, 1024, 1024])
