In [1]:
import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import torch.nn as nn
from model_helper import MobileNetV2, Inv2d
import wandb
import mysql.connector as connector
from pathlib import Path
import tqdm
import torchvision
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.models import mobilenet_v2


In [2]:
home = os.path.expanduser('~')

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device {device}')

Using device cuda


In [4]:
# research_dir = Path(home, 'Desktop', 'Education', 'Spring 2025', 'AI', 'research')
research_dir = Path(home, "ai_research_proj_spring_2025", "research_ai_class_spring_2025") # in lab 409 computer for Agafia
os.chdir(research_dir)

from data_helper import SQLDataset_Informative

os.chdir(home)

In [5]:
from torchvision.transforms import v2

# transforms
transformations = v2.Compose([
    v2.RandomResizedCrop(size=(224, 224), antialias=True), 
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [6]:
home = os.path.expanduser('~')
os.chdir(home) # b/c we will be using universal paths

host = '127.0.0.1'
user = 'root' # change to your username
password = 'vasya1' # change to your password
database = 'ai_proj_2025' # we should all have this as the db name 

try:
    conn = connector.connect(
        host = host, 
        user = user, 
        password = password, 
        database = database
    )
    print('success')
except connector.Error as err:
    print(err)

success


In [7]:
# create train, val, and test sets

data_dir=Path(home, 'OneDrive - Stephen F. Austin State University', 'CrisisMMD_v2.0','CrisisMMD_v2.0')

train_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_train=True)
val_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_val=True)
test_set = SQLDataset_Informative(conn=conn, img_col='image_path', label_col='image_info', transform=transformations, 
                     data_dir=data_dir, is_test=True)

In [8]:
train_set_2points = [train_set.__getitem__(i) for i in range(2)]
val_set_2points = [val_set.__getitem__(i) for i in range(2)]

In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=256)
val_loader = DataLoader(val_set, batch_size=128)
test_loader = DataLoader(test_set, batch_size=128)

# for data in train_loader:
#     print(data['label'])

In [10]:
#Define the convnet (Eventually replae with mobile 2)
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3) #specify the size (outer numbers) and amount (middle number) of filters
        self.pool = nn.MaxPool2d(2, 2) #specify pool size first number is size of pool, second is step size
        self.conv2 = nn.Conv2d(16, 8, 3) #new depth is amount of filters in previous conv layer
        self.fc1 = nn.Linear(54*54*8, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 2) #finial fc layer needs 19 outputs because we have 19 layers # ???

    def forward(self, x):

        x = F.relu(self.conv1(x))
     
        x = self.pool(x)
       
        x = F.relu(self.conv2(x))

        x = self.pool(x)

        x = x.view(-1, 54*54*8) # flatten

        x = F.relu(self.fc1(x))    #fully connected, relu         
        x = F.relu(self.fc2(x))    
       
        x = self.fc3(x)     #output    
        return x

In [11]:
# importing accuracy metric functions
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

**IMPORTANT**: Make sure to change the run_name (and 'architecture' parameter of the wandb `run` variable if necessary) with each new run. 

In [12]:
# validation fn
from collections import defaultdict 

def dev(model, val_loader):
    model.to(device)
    batch_size = val_loader.batch_size
    avg = 'macro' # used when computing certain accuracy metrics
    model.eval()

    eval_loss = 0

    all_preds = []
    all_trues = []

    with torch.no_grad():
        for b, batch in tqdm.tqdm(enumerate(val_loader), 
                             total= len(val_loader), desc=f"Processing validation data"):
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            raw_logits = model.forward(images)

            preds = torch.argmax(raw_logits, dim=1) # https://discuss.pytorch.org/t/cross-entropy-loss-get-predicted-class/58215

            loss = nn.CrossEntropyLoss()(raw_logits, labels)

            eval_loss += loss.item()

            all_preds.extend(preds.tolist())
            all_trues.extend(labels.tolist())


        # metrics 
        acc_total = accuracy_score(y_true=all_trues, y_pred=all_preds)
        precision = precision_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
        recall = recall_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)
        f1 = f1_score(y_true=all_trues, y_pred=all_preds, zero_division=0, average=avg)

        avg_eval_loss = eval_loss / (len(val_loader))

        metrics = {
            'accuracy': acc_total, 
            'precision': precision, 
            'recall': recall, 
            'f1': f1, 
            'avg_eval_loss': avg_eval_loss
        }
        wandb.log(metrics)
        print('****Evaluation****')
        print(f'total_accuracy: {acc_total}')

        return acc_total, zip(all_preds, all_trues)
    


In [13]:
def train_eval(model, num_epochs, run_name, lr, architecture, frozen_layers, dataset='CrisisMMD'):
    # training hyperparameters & functions/tools
    lr = lr 
    num_epochs = num_epochs
    run_name = run_name

    misclassified = defaultdict(int)
    

    best_val_acc = 0.0
    optimizer = torch.optim.Adam(model.parameters(), lr=lr) #stocastic gradient descent for our optimization algorithm
    lr_sched = MultiStepLR(optimizer=optimizer, milestones=list(range(50, num_epochs, 30)), gamma=.1)

    model.to(device)
    # for saving the models
    Path(research_dir, 'models' ).mkdir(parents=True, exist_ok=True)

    # before training, set up wandb for tracking purposes
    os.environ["WANDB_API_KEY"] = "5a08d1ebbf0e86ab877a128b98be3c320301b6a0"

    run = wandb.init(
        # Set the wandb entity where your project will be logged (generally your team name).
        entity="agafiabschool-stephen-f-austin-state-university",
        # Set the wandb project where this run will be logged.
        project="Research Project for CSCI-1465",
        # Track hyperparameters and run metadata.
        config={
            "learning_rate": lr,
            "architecture": architecture,
            "dataset": "CrisisMMD",
            "epochs": num_epochs,
            'frozen_layers': frozen_layers,
        }, name=run_name
    )


    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}')
        wandb.log({'epoch': epoch+1})

        loss = 0

        for b, batch in tqdm.tqdm(enumerate(train_loader), 
                            total= len(train_loader), desc=f"Processing training data in epoch {epoch+1}"):
            model.train()
            images = batch['image'].to(device)
            labels = batch['label'].to(device)

            model.zero_grad() 
            optimizer.zero_grad()

            # forward pass
            raw_logits = model.forward(images)
            # loss - can use raw logits for this function b/c it applies LogSoftmax 
            loss = nn.CrossEntropyLoss()(raw_logits, labels)
            print(f'Train Loss: {loss}')
            wandb.log({'Train Loss': loss.item()})
            wandb.log({'LR': lr_sched.get_last_lr()[0]})


            # backprop!
            loss.backward()

            optimizer.step()

            if (b+1) % 20 == 0:
                print(f'batch: {b+1} ; loss: {loss.item()}')
        
        # each epoch, run validation
        acc, zipped = dev(model=model, val_loader=val_loader)

        for i, (pred, true) in enumerate(zipped):
            if true != pred: 
                misclassified[i] += 1
        
        if acc > best_val_acc:
            best_val_acc = acc
            torch.save(model, Path(research_dir, 'models', f'{run_name}'))
        
        lr_sched.step()
    
    return best_val_acc, misclassified

In [None]:
# time to train
import time
start = time.time()

num_epochs = 10
os.chdir(home)

# for grid-search
history = []
# for lr in [10**-4, 30**-4, 10**-3, 30**-3, 10**-2, 30**-2, 10**-1, 30**-1, 1]:
for lr in [10**-3]:
    # run_name = f'MobileNetV2 lr={lr}'
    run_name = f'Basic CNN test -- lr={lr}'
    my_dict = {}
    my_dict['lr'] = lr

    # instantiate our model
    model = ConvNet()

        
    acc, misclassified = train_eval(model, num_epochs, run_name, lr=lr, architecture='Basic CNN', frozen_layers=0, dataset='CrisisMMD')
    my_dict['acc'] = acc
    my_dict['missed'] = misclassified
    history.append(my_dict)
end = time.time()

wandb: Currently logged in as: agafiabschool (agafiabschool-stephen-f-austin-state-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Epoch 1


Processing training data in epoch 1:   0%|          | 0/64 [00:00<?, ?it/s]

Train Loss: 0.685461163520813


Processing training data in epoch 1:   2%|▏         | 1/64 [00:01<01:16,  1.21s/it]

Train Loss: 3.1263749599456787


Processing training data in epoch 1:   3%|▎         | 2/64 [00:02<01:07,  1.10s/it]

Train Loss: 1.841259241104126


Processing training data in epoch 1:   5%|▍         | 3/64 [00:03<01:03,  1.05s/it]

Train Loss: 0.6574533581733704


Processing training data in epoch 1:   6%|▋         | 4/64 [00:04<00:58,  1.03it/s]

Train Loss: 0.713516891002655


Processing training data in epoch 1:   8%|▊         | 5/64 [00:04<00:54,  1.09it/s]

Train Loss: 0.7470666170120239


Processing training data in epoch 1:   9%|▉         | 6/64 [00:05<00:51,  1.12it/s]

Train Loss: 0.7539517283439636


Processing training data in epoch 1:  11%|█         | 7/64 [00:06<00:54,  1.04it/s]

Train Loss: 0.6862316727638245


Processing training data in epoch 1:  12%|█▎        | 8/64 [00:07<00:52,  1.07it/s]

Train Loss: 0.674427330493927


Processing training data in epoch 1:  14%|█▍        | 9/64 [00:08<00:49,  1.10it/s]

Train Loss: 0.6794255375862122


Processing training data in epoch 1:  16%|█▌        | 10/64 [00:09<00:48,  1.12it/s]

Train Loss: 0.694210410118103


Processing training data in epoch 1:  17%|█▋        | 11/64 [00:10<00:46,  1.13it/s]

Train Loss: 0.718181848526001


Processing training data in epoch 1:  19%|█▉        | 12/64 [00:11<00:45,  1.15it/s]

Train Loss: 0.712857723236084


Processing training data in epoch 1:  20%|██        | 13/64 [00:12<00:44,  1.14it/s]

Train Loss: 0.7003750801086426


Processing training data in epoch 1:  22%|██▏       | 14/64 [00:12<00:43,  1.15it/s]

Train Loss: 0.6950231790542603


Processing training data in epoch 1:  23%|██▎       | 15/64 [00:13<00:42,  1.14it/s]

Train Loss: 0.7129763960838318


Processing training data in epoch 1:  25%|██▌       | 16/64 [00:14<00:42,  1.13it/s]

Train Loss: 0.7091782689094543


Processing training data in epoch 1:  27%|██▋       | 17/64 [00:15<00:41,  1.13it/s]

Train Loss: 0.6899487972259521


Processing training data in epoch 1:  28%|██▊       | 18/64 [00:16<00:40,  1.15it/s]

Train Loss: 0.681421160697937


Processing training data in epoch 1:  30%|██▉       | 19/64 [00:17<00:39,  1.14it/s]

Train Loss: 0.6849035620689392
batch: 20 ; loss: 0.6849035620689392


Processing training data in epoch 1:  31%|███▏      | 20/64 [00:18<00:40,  1.10it/s]

Train Loss: 0.6821866631507874


Processing training data in epoch 1:  33%|███▎      | 21/64 [00:19<00:39,  1.10it/s]

Train Loss: 0.7256664037704468


Processing training data in epoch 1:  34%|███▍      | 22/64 [00:20<00:37,  1.13it/s]

Train Loss: 0.7083988785743713


Processing training data in epoch 1:  36%|███▌      | 23/64 [00:20<00:35,  1.16it/s]

Train Loss: 0.6886200904846191


Processing training data in epoch 1:  38%|███▊      | 24/64 [00:21<00:34,  1.16it/s]

Train Loss: 0.6926554441452026


Processing training data in epoch 1:  39%|███▉      | 25/64 [00:22<00:34,  1.12it/s]

Train Loss: 0.7106735706329346


Processing training data in epoch 1:  41%|████      | 26/64 [00:23<00:35,  1.08it/s]

Train Loss: 0.7117250561714172


Processing training data in epoch 1:  42%|████▏     | 27/64 [00:24<00:34,  1.08it/s]

Train Loss: 0.7184265851974487


Processing training data in epoch 1:  44%|████▍     | 28/64 [00:25<00:33,  1.08it/s]

Train Loss: 0.7273117303848267


Processing training data in epoch 1:  45%|████▌     | 29/64 [00:26<00:32,  1.08it/s]

Train Loss: 0.7185964584350586


Processing training data in epoch 1:  47%|████▋     | 30/64 [00:27<00:31,  1.09it/s]

Train Loss: 0.7185862064361572


Processing training data in epoch 1:  48%|████▊     | 31/64 [00:28<00:30,  1.09it/s]

Train Loss: 0.6959285736083984


Processing training data in epoch 1:  50%|█████     | 32/64 [00:29<00:30,  1.06it/s]

Train Loss: 0.692748486995697


Processing training data in epoch 1:  52%|█████▏    | 33/64 [00:30<00:29,  1.05it/s]

Train Loss: 0.6889830827713013


Processing training data in epoch 1:  53%|█████▎    | 34/64 [00:31<00:28,  1.04it/s]

Train Loss: 0.6827639937400818


Processing training data in epoch 1:  55%|█████▍    | 35/64 [00:32<00:28,  1.02it/s]

Train Loss: 0.6859844923019409


Processing training data in epoch 1:  56%|█████▋    | 36/64 [00:33<00:27,  1.01it/s]

Train Loss: 0.6790741682052612


Processing training data in epoch 1:  58%|█████▊    | 37/64 [00:34<00:26,  1.03it/s]

Train Loss: 0.762104332447052


Processing training data in epoch 1:  59%|█████▉    | 38/64 [00:35<00:24,  1.07it/s]

Train Loss: 0.7488265037536621


Processing training data in epoch 1:  61%|██████    | 39/64 [00:35<00:22,  1.10it/s]

Train Loss: 0.7500255107879639
batch: 40 ; loss: 0.7500255107879639


Processing training data in epoch 1:  62%|██████▎   | 40/64 [00:36<00:22,  1.08it/s]

Train Loss: 0.7408356070518494


Processing training data in epoch 1:  64%|██████▍   | 41/64 [00:37<00:21,  1.09it/s]

Train Loss: 0.7126501798629761


Processing training data in epoch 1:  66%|██████▌   | 42/64 [00:38<00:19,  1.11it/s]

Train Loss: 0.7002103328704834


Processing training data in epoch 1:  67%|██████▋   | 43/64 [00:39<00:18,  1.12it/s]

Train Loss: 0.694526731967926


Processing training data in epoch 1:  69%|██████▉   | 44/64 [00:40<00:17,  1.12it/s]

Train Loss: 0.6907820105552673


Processing training data in epoch 1:  70%|███████   | 45/64 [00:41<00:17,  1.11it/s]

Train Loss: 0.6943063735961914


Processing training data in epoch 1:  72%|███████▏  | 46/64 [00:42<00:16,  1.12it/s]

Train Loss: 0.6990026235580444


Processing training data in epoch 1:  73%|███████▎  | 47/64 [00:43<00:15,  1.12it/s]

Train Loss: 0.6975648999214172


Processing training data in epoch 1:  75%|███████▌  | 48/64 [00:43<00:14,  1.13it/s]

Train Loss: 0.700809895992279


Processing training data in epoch 1:  77%|███████▋  | 49/64 [00:44<00:13,  1.14it/s]

Train Loss: 0.6992950439453125


Processing training data in epoch 1:  78%|███████▊  | 50/64 [00:45<00:12,  1.14it/s]

Train Loss: 0.7023670673370361


Processing training data in epoch 1:  80%|███████▉  | 51/64 [00:46<00:11,  1.14it/s]

Train Loss: 0.7037017345428467


Processing training data in epoch 1:  81%|████████▏ | 52/64 [00:47<00:10,  1.14it/s]

Train Loss: 0.6943080425262451


Processing training data in epoch 1:  83%|████████▎ | 53/64 [00:48<00:09,  1.14it/s]

Train Loss: 0.6881586313247681


Processing training data in epoch 1:  84%|████████▍ | 54/64 [00:49<00:08,  1.17it/s]

Train Loss: 0.6882871389389038


Processing training data in epoch 1:  86%|████████▌ | 55/64 [00:49<00:07,  1.17it/s]

Train Loss: 0.6885460615158081


Processing training data in epoch 1:  88%|████████▊ | 56/64 [00:50<00:06,  1.15it/s]

Train Loss: 0.6876941323280334


Processing training data in epoch 1:  89%|████████▉ | 57/64 [00:51<00:06,  1.16it/s]

Train Loss: 0.6876626014709473


Processing training data in epoch 1:  91%|█████████ | 58/64 [00:52<00:05,  1.18it/s]

Train Loss: 0.692563533782959


Processing training data in epoch 1:  92%|█████████▏| 59/64 [00:53<00:04,  1.19it/s]

Train Loss: 0.6915305852890015
batch: 60 ; loss: 0.6915305852890015


Processing training data in epoch 1:  94%|█████████▍| 60/64 [00:54<00:03,  1.18it/s]

Train Loss: 0.689386248588562


Processing training data in epoch 1:  95%|█████████▌| 61/64 [00:55<00:02,  1.19it/s]

Train Loss: 0.6981763243675232


Processing training data in epoch 1:  97%|█████████▋| 62/64 [00:55<00:01,  1.20it/s]

Train Loss: 0.7030037045478821


Processing training data in epoch 1:  98%|█████████▊| 63/64 [00:56<00:00,  1.18it/s]

Train Loss: 0.7077253460884094


Processing training data in epoch 1: 100%|██████████| 64/64 [00:57<00:00,  1.12it/s]
Processing validation data: 100%|██████████| 8/8 [00:03<00:00,  2.45it/s]


****Evaluation****
total_accuracy: 0.5542035398230089
Epoch 2


Processing training data in epoch 2:   0%|          | 0/64 [00:00<?, ?it/s]

Train Loss: 0.6966164112091064


Processing training data in epoch 2:   2%|▏         | 1/64 [00:00<01:01,  1.03it/s]

Train Loss: 0.6951183676719666


Processing training data in epoch 2:   3%|▎         | 2/64 [00:01<00:56,  1.10it/s]

Train Loss: 0.694015383720398


Processing training data in epoch 2:   5%|▍         | 3/64 [00:02<00:51,  1.18it/s]

Train Loss: 0.6946074962615967


Processing training data in epoch 2:   6%|▋         | 4/64 [00:03<00:50,  1.19it/s]

Train Loss: 0.694968044757843


Processing training data in epoch 2:   8%|▊         | 5/64 [00:04<00:48,  1.21it/s]

Train Loss: 0.693414568901062


Processing training data in epoch 2:   9%|▉         | 6/64 [00:05<00:48,  1.21it/s]

Train Loss: 0.6944633722305298


Processing training data in epoch 2:  11%|█         | 7/64 [00:05<00:48,  1.19it/s]

Train Loss: 0.6923880577087402


Processing training data in epoch 2:  12%|█▎        | 8/64 [00:06<00:47,  1.18it/s]

Train Loss: 0.6900097131729126


Processing training data in epoch 2:  14%|█▍        | 9/64 [00:07<00:46,  1.19it/s]

Train Loss: 0.6893888711929321


Processing training data in epoch 2:  16%|█▌        | 10/64 [00:08<00:44,  1.20it/s]

Train Loss: 0.6909520030021667


Processing training data in epoch 2:  17%|█▋        | 11/64 [00:09<00:44,  1.20it/s]

Train Loss: 0.6899014711380005


Processing training data in epoch 2:  19%|█▉        | 12/64 [00:10<00:42,  1.21it/s]

Train Loss: 0.6899839639663696


Processing training data in epoch 2:  20%|██        | 13/64 [00:10<00:42,  1.19it/s]

Train Loss: 0.692211925983429


Processing training data in epoch 2:  22%|██▏       | 14/64 [00:11<00:42,  1.19it/s]

Train Loss: 0.6961523294448853


Processing training data in epoch 2:  23%|██▎       | 15/64 [00:12<00:41,  1.18it/s]

Train Loss: 0.6864972114562988


Processing training data in epoch 2:  25%|██▌       | 16/64 [00:13<00:41,  1.16it/s]

Train Loss: 0.6892929077148438


Processing training data in epoch 2:  27%|██▋       | 17/64 [00:14<00:40,  1.16it/s]

Train Loss: 0.6895900368690491


Processing training data in epoch 2:  28%|██▊       | 18/64 [00:15<00:38,  1.19it/s]

Train Loss: 0.6921069025993347


Processing training data in epoch 2:  30%|██▉       | 19/64 [00:16<00:37,  1.19it/s]

Train Loss: 0.695046603679657
batch: 20 ; loss: 0.695046603679657


Processing training data in epoch 2:  31%|███▏      | 20/64 [00:17<00:38,  1.13it/s]

Train Loss: 0.6947031021118164


Processing training data in epoch 2:  33%|███▎      | 21/64 [00:17<00:37,  1.13it/s]

Train Loss: 0.6847718358039856


Processing training data in epoch 2:  34%|███▍      | 22/64 [00:18<00:35,  1.17it/s]

Train Loss: 0.677966833114624


Processing training data in epoch 2:  36%|███▌      | 23/64 [00:19<00:33,  1.21it/s]

Train Loss: 0.6867771148681641


Processing training data in epoch 2:  38%|███▊      | 24/64 [00:20<00:33,  1.20it/s]

Train Loss: 0.6884025931358337


Processing training data in epoch 2:  39%|███▉      | 25/64 [00:21<00:33,  1.18it/s]

Train Loss: 0.6910014152526855


Processing training data in epoch 2:  41%|████      | 26/64 [00:22<00:33,  1.15it/s]

Train Loss: 0.6929228901863098


Processing training data in epoch 2:  42%|████▏     | 27/64 [00:22<00:32,  1.15it/s]

Train Loss: 0.6930840015411377


Processing training data in epoch 2:  44%|████▍     | 28/64 [00:23<00:31,  1.16it/s]

Train Loss: 0.6979994773864746


Processing training data in epoch 2:  45%|████▌     | 29/64 [00:24<00:30,  1.15it/s]

Train Loss: 0.6966859102249146


Processing training data in epoch 2:  47%|████▋     | 30/64 [00:25<00:29,  1.15it/s]

Train Loss: 0.7018151879310608


Processing training data in epoch 2:  48%|████▊     | 31/64 [00:26<00:28,  1.14it/s]

Train Loss: 0.6921390891075134


Processing training data in epoch 2:  50%|█████     | 32/64 [00:27<00:28,  1.13it/s]

Train Loss: 0.6955468058586121


Processing training data in epoch 2:  52%|█████▏    | 33/64 [00:28<00:27,  1.14it/s]

Train Loss: 0.6937297582626343


Processing training data in epoch 2:  53%|█████▎    | 34/64 [00:29<00:25,  1.16it/s]

Train Loss: 0.6977639198303223


Processing training data in epoch 2:  55%|█████▍    | 35/64 [00:29<00:25,  1.14it/s]

Train Loss: 0.7006604671478271


Processing training data in epoch 2:  56%|█████▋    | 36/64 [00:30<00:24,  1.14it/s]

Train Loss: 0.6980662941932678


Processing training data in epoch 2:  58%|█████▊    | 37/64 [00:31<00:23,  1.14it/s]

Train Loss: 0.668484091758728


Processing training data in epoch 2:  59%|█████▉    | 38/64 [00:32<00:22,  1.18it/s]

Train Loss: 0.6679157614707947


Processing training data in epoch 2:  61%|██████    | 39/64 [00:33<00:20,  1.19it/s]

Train Loss: 0.6649847626686096
batch: 40 ; loss: 0.6649847626686096


Processing training data in epoch 2:  62%|██████▎   | 40/64 [00:34<00:20,  1.17it/s]

Train Loss: 0.66912841796875


Processing training data in epoch 2:  64%|██████▍   | 41/64 [00:35<00:19,  1.17it/s]

Train Loss: 0.6645679473876953


Processing training data in epoch 2:  66%|██████▌   | 42/64 [00:35<00:18,  1.18it/s]

Train Loss: 0.6703693866729736


Processing training data in epoch 2:  67%|██████▋   | 43/64 [00:36<00:17,  1.18it/s]

Train Loss: 0.6764172315597534


Processing training data in epoch 2:  69%|██████▉   | 44/64 [00:37<00:16,  1.19it/s]

Train Loss: 0.7121334075927734


Processing training data in epoch 2:  70%|███████   | 45/64 [00:38<00:16,  1.18it/s]

Train Loss: 0.7163678407669067


Processing training data in epoch 2:  72%|███████▏  | 46/64 [00:39<00:15,  1.19it/s]

Train Loss: 0.7296193838119507


Processing training data in epoch 2:  73%|███████▎  | 47/64 [00:40<00:14,  1.18it/s]

Train Loss: 0.72037273645401


Processing training data in epoch 2:  75%|███████▌  | 48/64 [00:40<00:13,  1.18it/s]

Train Loss: 0.7145230174064636


Processing training data in epoch 2:  77%|███████▋  | 49/64 [00:41<00:12,  1.19it/s]

Train Loss: 0.7030789256095886


Processing training data in epoch 2:  78%|███████▊  | 50/64 [00:42<00:11,  1.19it/s]

Train Loss: 0.7102729678153992


Processing training data in epoch 2:  80%|███████▉  | 51/64 [00:43<00:10,  1.20it/s]

Train Loss: 0.7044230103492737


Processing training data in epoch 2:  81%|████████▏ | 52/64 [00:44<00:10,  1.20it/s]

Train Loss: 0.6959487199783325


Processing training data in epoch 2:  83%|████████▎ | 53/64 [00:45<00:09,  1.20it/s]

Train Loss: 0.6875272989273071


Processing training data in epoch 2:  84%|████████▍ | 54/64 [00:45<00:08,  1.22it/s]

Train Loss: 0.6891255378723145


Processing training data in epoch 2:  86%|████████▌ | 55/64 [00:46<00:07,  1.24it/s]

Train Loss: 0.6917039752006531


Processing training data in epoch 2:  88%|████████▊ | 56/64 [00:47<00:06,  1.24it/s]

Train Loss: 0.690819263458252


Processing training data in epoch 2:  89%|████████▉ | 57/64 [00:48<00:05,  1.25it/s]

Train Loss: 0.6898791193962097


Processing training data in epoch 2:  91%|█████████ | 58/64 [00:49<00:04,  1.23it/s]

Train Loss: 0.6912925839424133


Processing training data in epoch 2:  92%|█████████▏| 59/64 [00:49<00:04,  1.23it/s]

Train Loss: 0.6885873675346375
batch: 60 ; loss: 0.6885873675346375


Processing training data in epoch 2:  94%|█████████▍| 60/64 [00:50<00:03,  1.20it/s]

Train Loss: 0.6873135566711426


Processing training data in epoch 2:  95%|█████████▌| 61/64 [00:51<00:02,  1.22it/s]

Train Loss: 0.690718412399292


Processing training data in epoch 2:  97%|█████████▋| 62/64 [00:52<00:01,  1.23it/s]

Train Loss: 0.6916154623031616


Processing training data in epoch 2:  98%|█████████▊| 63/64 [00:53<00:00,  1.20it/s]

Train Loss: 0.6904694437980652


Processing training data in epoch 2: 100%|██████████| 64/64 [00:53<00:00,  1.19it/s]
Processing validation data: 100%|██████████| 8/8 [00:03<00:00,  2.66it/s]


****Evaluation****
total_accuracy: 0.5420353982300885
Epoch 3


Processing training data in epoch 3:   0%|          | 0/64 [00:00<?, ?it/s]

Train Loss: 0.707499086856842


Processing training data in epoch 3:   2%|▏         | 1/64 [00:00<01:00,  1.04it/s]

Train Loss: 0.7049330472946167


Processing training data in epoch 3:   3%|▎         | 2/64 [00:01<00:53,  1.17it/s]

Train Loss: 0.699703574180603


Processing training data in epoch 3:   5%|▍         | 3/64 [00:02<00:50,  1.21it/s]

Train Loss: 0.6999161839485168


Processing training data in epoch 3:   6%|▋         | 4/64 [00:03<00:49,  1.21it/s]

Train Loss: 0.6934714317321777


Processing training data in epoch 3:   8%|▊         | 5/64 [00:04<00:48,  1.22it/s]

In [None]:
total_time = end - start

In [None]:
my_dict['missed']

defaultdict(int,
            {0: 1,
             1: 1,
             2: 1,
             3: 1,
             4: 1,
             5: 1,
             6: 1,
             9: 1,
             10: 1,
             12: 1,
             13: 1,
             14: 1,
             15: 1,
             16: 1,
             17: 1,
             18: 1,
             19: 1,
             20: 1,
             22: 1,
             23: 1,
             24: 1,
             25: 1,
             29: 1,
             30: 1,
             31: 1,
             32: 1,
             33: 1,
             34: 1,
             35: 1,
             37: 1,
             38: 1,
             39: 1,
             42: 1,
             43: 1,
             44: 1,
             45: 1,
             46: 1,
             47: 1,
             48: 1,
             50: 1,
             52: 1,
             53: 1,
             54: 1,
             56: 1,
             57: 1,
             59: 1,
             61: 1,
             63: 1,
             64: 1,
           