In [89]:
from dataloader import ESC_Dataset
import torch
from models.vit import ViT
import os
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import torch.nn as nn

In [90]:
full_dataset = np.load(os.path.join("esc-50", "esc-50-data.npy"), allow_pickle = True)

In [91]:
train_dataset = ESC_Dataset(dataset = full_dataset, esc_fold=0, eval_mode = False)
eval_dataset = ESC_Dataset(dataset = full_dataset, esc_fold=0, eval_mode = True)

In [92]:
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, collate_fn=None, pin_memory=False)
eval_dataloader = DataLoader(eval_dataset, batch_size=128, shuffle=False, num_workers=4, collate_fn=None, pin_memory=False)

In [93]:
model = ViT(
    image_size = (320,128),
    patch_size = (40,16),
    channels = 1,
    num_classes = 50,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.0,
    emb_dropout = 0.0
)

In [94]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [95]:
model = model.to(device)

In [96]:
def eval_model(model, eval_dataset, device):
    model.eval()
    forecast, true_labs = [], []
    with torch.no_grad():
        for data, labs in tqdm(eval_dataset):
            data, labs = data.to(device), labs[:,0].cpu()
            true_labs.append(labs)
            outputs = model(data)
            
            outputs = outputs.detach().cpu().numpy().argmax(axis=1)
            forecast.append(outputs)
    forecast = [x for sublist in forecast for x in sublist]
    true_labs = [x for sublist in true_labs for x in sublist]
    return f1_score(forecast, true_labs, average='macro'), accuracy_score(forecast, true_labs)

In [97]:
criterion = nn.CrossEntropyLoss()

In [98]:
n_epoch = 100
best_f1 = 0
lr = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(n_epoch):
    model.train()
    for data, labs in tqdm(train_dataloader):
        data, labs = data.to(device), labs.to(device)[:,0]
        outputs = model(data)
        loss = criterion(outputs, labs) 
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # del outputs
        # torch.cuda.empty_cache( )
#     if epoch % 10 == 0:
    f1, accuracy = eval_model(model, eval_dataloader, device)
    f1_train, accuracy_train = eval_model(model, train_dataloader, device)
    print(f'epoch: {epoch}, f1_test: {f1}, accuracy_test: {accuracy}, f1_train: {f1_train},  accuracy_train: {accuracy_train}')
    if f1 > best_f1:
        best_f1 = f1
        torch.save(model.state_dict(), 'classic_vit.pt')

    lr = lr * 0.95
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

100%|██████████| 13/13 [01:26<00:00,  6.64s/it]
100%|██████████| 4/4 [00:10<00:00,  2.68s/it]
100%|██████████| 13/13 [00:31<00:00,  2.45s/it]


epoch: 0, f1_test: 0.0008163265306122449, accuracy_test: 0.02, f1_train: 0.004280768706555813,  accuracy_train: 0.02375


100%|██████████| 13/13 [01:25<00:00,  6.61s/it]
100%|██████████| 4/4 [00:08<00:00,  2.21s/it]
100%|██████████| 13/13 [00:30<00:00,  2.36s/it]


epoch: 1, f1_test: 0.013016451646588632, accuracy_test: 0.05, f1_train: 0.0058397721390283505,  accuracy_train: 0.038125


100%|██████████| 13/13 [01:34<00:00,  7.26s/it]
100%|██████████| 4/4 [00:08<00:00,  2.16s/it]
100%|██████████| 13/13 [00:31<00:00,  2.39s/it]


epoch: 2, f1_test: 0.02458299814725023, accuracy_test: 0.07, f1_train: 0.013755366733368544,  accuracy_train: 0.050625


100%|██████████| 13/13 [01:43<00:00,  7.94s/it]
100%|██████████| 4/4 [00:08<00:00,  2.24s/it]
100%|██████████| 13/13 [00:30<00:00,  2.33s/it]


epoch: 3, f1_test: 0.0223861873909118, accuracy_test: 0.0675, f1_train: 0.03831591721499227,  accuracy_train: 0.088125


100%|██████████| 13/13 [01:43<00:00,  7.95s/it]
100%|██████████| 4/4 [00:11<00:00,  2.84s/it]
100%|██████████| 13/13 [00:31<00:00,  2.44s/it]


epoch: 4, f1_test: 0.02699964213368804, accuracy_test: 0.0725, f1_train: 0.043508502179321146,  accuracy_train: 0.10125


100%|██████████| 13/13 [01:42<00:00,  7.88s/it]
100%|██████████| 4/4 [00:09<00:00,  2.31s/it]
100%|██████████| 13/13 [00:30<00:00,  2.31s/it]


epoch: 5, f1_test: 0.06660081303392125, accuracy_test: 0.1, f1_train: 0.07343327472424675,  accuracy_train: 0.116875


100%|██████████| 13/13 [01:46<00:00,  8.16s/it]
100%|██████████| 4/4 [00:10<00:00,  2.53s/it]
100%|██████████| 13/13 [00:30<00:00,  2.33s/it]


epoch: 6, f1_test: 0.05593521564782408, accuracy_test: 0.0925, f1_train: 0.0685175119629104,  accuracy_train: 0.11625


100%|██████████| 13/13 [01:37<00:00,  7.47s/it]
100%|██████████| 4/4 [00:07<00:00,  1.94s/it]
100%|██████████| 13/13 [00:27<00:00,  2.09s/it]


epoch: 7, f1_test: 0.044222951689950614, accuracy_test: 0.09, f1_train: 0.07662943925676316,  accuracy_train: 0.13125


100%|██████████| 13/13 [01:39<00:00,  7.62s/it]
100%|██████████| 4/4 [00:10<00:00,  2.53s/it]
100%|██████████| 13/13 [00:26<00:00,  2.04s/it]


epoch: 8, f1_test: 0.06609817557700028, accuracy_test: 0.1025, f1_train: 0.11071982987185322,  accuracy_train: 0.153125


100%|██████████| 13/13 [01:38<00:00,  7.56s/it]
100%|██████████| 4/4 [00:07<00:00,  1.98s/it]
100%|██████████| 13/13 [00:26<00:00,  2.03s/it]


epoch: 9, f1_test: 0.08468035624930134, accuracy_test: 0.1225, f1_train: 0.15915913608115592,  accuracy_train: 0.19625


100%|██████████| 13/13 [01:48<00:00,  8.32s/it]
100%|██████████| 4/4 [00:07<00:00,  1.96s/it]
100%|██████████| 13/13 [00:38<00:00,  2.99s/it]


epoch: 10, f1_test: 0.11742268650228826, accuracy_test: 0.1575, f1_train: 0.16995906116768061,  accuracy_train: 0.20875


100%|██████████| 13/13 [02:19<00:00, 10.76s/it]
100%|██████████| 4/4 [00:11<00:00,  2.82s/it]
100%|██████████| 13/13 [00:38<00:00,  2.97s/it]


epoch: 11, f1_test: 0.10287117228223303, accuracy_test: 0.15, f1_train: 0.1993950383252012,  accuracy_train: 0.24125


100%|██████████| 13/13 [02:00<00:00,  9.25s/it]
100%|██████████| 4/4 [00:08<00:00,  2.11s/it]
100%|██████████| 13/13 [00:27<00:00,  2.15s/it]


epoch: 12, f1_test: 0.14096402064304425, accuracy_test: 0.18, f1_train: 0.22425938880983298,  accuracy_train: 0.265


100%|██████████| 13/13 [01:47<00:00,  8.28s/it]
100%|██████████| 4/4 [00:07<00:00,  1.97s/it]
100%|██████████| 13/13 [00:25<00:00,  1.96s/it]


epoch: 13, f1_test: 0.15194204163486702, accuracy_test: 0.1775, f1_train: 0.2851094806281918,  accuracy_train: 0.31125


100%|██████████| 13/13 [01:42<00:00,  7.85s/it]
100%|██████████| 4/4 [00:07<00:00,  1.89s/it]
100%|██████████| 13/13 [00:25<00:00,  1.95s/it]


epoch: 14, f1_test: 0.12344295084785681, accuracy_test: 0.17, f1_train: 0.2634154477574576,  accuracy_train: 0.296875


100%|██████████| 13/13 [02:00<00:00,  9.30s/it]
100%|██████████| 4/4 [00:07<00:00,  1.89s/it]
100%|██████████| 13/13 [00:27<00:00,  2.14s/it]


epoch: 15, f1_test: 0.16182083173936776, accuracy_test: 0.185, f1_train: 0.33296290225599007,  accuracy_train: 0.35125


100%|██████████| 13/13 [02:05<00:00,  9.69s/it]
100%|██████████| 4/4 [00:09<00:00,  2.27s/it]
100%|██████████| 13/13 [00:29<00:00,  2.30s/it]


epoch: 16, f1_test: 0.1588809194972878, accuracy_test: 0.1825, f1_train: 0.3736369894636805,  accuracy_train: 0.3975


100%|██████████| 13/13 [02:01<00:00,  9.38s/it]
100%|██████████| 4/4 [00:09<00:00,  2.42s/it]
100%|██████████| 13/13 [00:31<00:00,  2.46s/it]


epoch: 17, f1_test: 0.1488575955632027, accuracy_test: 0.185, f1_train: 0.38089761188698384,  accuracy_train: 0.400625


100%|██████████| 13/13 [02:03<00:00,  9.53s/it]
100%|██████████| 4/4 [00:11<00:00,  2.83s/it]
100%|██████████| 13/13 [00:33<00:00,  2.57s/it]


epoch: 18, f1_test: 0.15968254834794093, accuracy_test: 0.195, f1_train: 0.39484455563930604,  accuracy_train: 0.4125


100%|██████████| 13/13 [02:04<00:00,  9.58s/it]
100%|██████████| 4/4 [00:09<00:00,  2.35s/it]
100%|██████████| 13/13 [00:35<00:00,  2.76s/it]


epoch: 19, f1_test: 0.19004483137706152, accuracy_test: 0.21, f1_train: 0.4213141440910501,  accuracy_train: 0.436875


100%|██████████| 13/13 [02:02<00:00,  9.45s/it]
100%|██████████| 4/4 [00:09<00:00,  2.35s/it]
100%|██████████| 13/13 [00:32<00:00,  2.47s/it]


epoch: 20, f1_test: 0.1783728309833046, accuracy_test: 0.1925, f1_train: 0.47077596757953516,  accuracy_train: 0.48375


100%|██████████| 13/13 [02:01<00:00,  9.38s/it]
100%|██████████| 4/4 [00:10<00:00,  2.73s/it]
100%|██████████| 13/13 [00:32<00:00,  2.48s/it]


epoch: 21, f1_test: 0.1814713652360711, accuracy_test: 0.215, f1_train: 0.4451472819445168,  accuracy_train: 0.468125


100%|██████████| 13/13 [02:09<00:00,  9.99s/it]
100%|██████████| 4/4 [00:10<00:00,  2.67s/it]
100%|██████████| 13/13 [00:30<00:00,  2.32s/it]


epoch: 22, f1_test: 0.22864244108145038, accuracy_test: 0.25, f1_train: 0.582753339814355,  accuracy_train: 0.596875


100%|██████████| 13/13 [02:07<00:00,  9.83s/it]
100%|██████████| 4/4 [00:09<00:00,  2.41s/it]
100%|██████████| 13/13 [00:32<00:00,  2.51s/it]


epoch: 23, f1_test: 0.21687663749043012, accuracy_test: 0.2425, f1_train: 0.5869565519053129,  accuracy_train: 0.60375


100%|██████████| 13/13 [02:06<00:00,  9.71s/it]
100%|██████████| 4/4 [00:10<00:00,  2.56s/it]
100%|██████████| 13/13 [00:32<00:00,  2.48s/it]


epoch: 24, f1_test: 0.22125403008838374, accuracy_test: 0.2475, f1_train: 0.5763082070267392,  accuracy_train: 0.59625


100%|██████████| 13/13 [02:09<00:00, 10.00s/it]
100%|██████████| 4/4 [00:09<00:00,  2.30s/it]
100%|██████████| 13/13 [00:31<00:00,  2.39s/it]


epoch: 25, f1_test: 0.22736661701871536, accuracy_test: 0.2525, f1_train: 0.6864620306434318,  accuracy_train: 0.6875


100%|██████████| 13/13 [02:13<00:00, 10.28s/it]
100%|██████████| 4/4 [00:09<00:00,  2.43s/it]
100%|██████████| 13/13 [00:33<00:00,  2.60s/it]


epoch: 26, f1_test: 0.23686431554435758, accuracy_test: 0.2625, f1_train: 0.6817024232440657,  accuracy_train: 0.685625


100%|██████████| 13/13 [02:07<00:00,  9.82s/it]
100%|██████████| 4/4 [00:07<00:00,  1.98s/it]
100%|██████████| 13/13 [00:26<00:00,  2.07s/it]


epoch: 27, f1_test: 0.22131438785165247, accuracy_test: 0.2425, f1_train: 0.7446864262074417,  accuracy_train: 0.746875


100%|██████████| 13/13 [01:54<00:00,  8.79s/it]
100%|██████████| 4/4 [00:07<00:00,  1.85s/it]
100%|██████████| 13/13 [00:25<00:00,  1.98s/it]


epoch: 28, f1_test: 0.25710208200886625, accuracy_test: 0.27, f1_train: 0.7057215508371408,  accuracy_train: 0.71375


 92%|█████████▏| 12/13 [02:03<00:10, 10.32s/it]


KeyboardInterrupt: 