In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import numpy as np
import shelve
import random
import cv2 as cv
import matplotlib.pyplot as plt

import importlib
import model_1
import utils
_ = importlib.reload(model_1)
_ = importlib.reload(utils)

In [2]:
DEVICE = 'cuda'
BATCH_SIZE = 10
SHELVE_PATH = 'data/processed-data/data-1/db'

## Data Preparation

In [3]:
ds = model_1.ShelveDataset(SHELVE_PATH)
dl = DataLoader(
    ds, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    collate_fn=model_1.ShelveDataset.collate_fn, 
    num_workers=8,
)

## Training Loop

In [4]:
model = model_1.AvgPoolingCnn().to(DEVICE)
# model = torch.load('models/SCS-model-dict.pt')

In [6]:
# 1) Set up loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)
optimizer = torch.optim.AdamW([
	{'params': model.conv.parameters(), "lr": 0.0001},
	{'params': model.fc1.parameters(), "lr": 0.0001},
	{'params': model.out_v1.parameters(), "lr": 0.0001},
	{'params': model.out_v2.parameters(), "lr": 0.0001},
	{'params': model.out_v3.parameters(), "lr": 0.0001},
	{'params': model.out_v4.parameters(), "lr": 0.0001},
	{'params': model.out_v5.parameters(), "lr": 0.0001},
])

lr = 5e-6
optimizers = [
	torch.optim.AdamW(model.out_v1.parameters(), lr=lr),
	torch.optim.AdamW(model.out_v2.parameters(), lr=lr),
	torch.optim.AdamW(model.out_v3.parameters(), lr=lr),
	torch.optim.AdamW(model.out_v4.parameters(), lr=lr),
	torch.optim.AdamW(model.out_v5.parameters(), lr=lr),
	torch.optim.AdamW([
		{'params': model.conv.parameters(), "lr": lr},
		{'params': model.fc1.parameters(), "lr": lr},
	]),
]

report_interval = 12
accuracy_hist = utils.Accuracy()
loss_hist = []
condition = 'RNFN'

model.train()


epochs = 100
for epoch in range(epochs):
	optimizer.zero_grad()
	for optimizer in optimizers:
		optimizer.zero_grad()

	for batch_idx, (X_train, y_train) in enumerate(dl):
		for k, v in X_train.items():
			X_train[k] = v.to(DEVICE)
		for k, v in y_train.items():
			y_train[k] = v.to(DEVICE)

		# 2) Forward Propagation
		y_pred = model(X_train)
		# break

		# 3) loss calculation
		y_train_labeled = y_train[condition].argmax(dim=2)
		loss = 0
		for i in range(5):
			loss += criterion(y_pred[:, i, :], y_train_labeled[:, i])
				
		loss.backward()
		for opt in optimizers:
			opt.step()

		acc = accuracy_hist.calc_accuracy(y_pred, y_train[condition], epoch)
		loss_hist.append(loss.item())

		if batch_idx > 0 and batch_idx % report_interval == 0:
			curr_loss = sum(loss_hist[-10:]) / report_interval
			print(f"[Epoch {f'{epoch+1}/{epochs}':<7}][Batch {batch_idx:<3}]\tLoss: {loss:.4f}\tAccuracy: {acc} [{acc.mean():.4f}]")
			torch.cuda.empty_cache()



[Epoch 1/100  ][Batch 12 ]	Loss: 1.6833	Accuracy: [0.95384615 0.93076923 0.81538462 0.68461538 0.75384615] [0.8277]
[Epoch 1/100  ][Batch 24 ]	Loss: 1.6922	Accuracy: [0.956 0.928 0.816 0.656 0.76 ] [0.8232]
[Epoch 1/100  ][Batch 36 ]	Loss: 2.0276	Accuracy: [0.95945946 0.92972973 0.81351351 0.66216216 0.74594595] [0.8222]
[Epoch 1/100  ][Batch 48 ]	Loss: 1.7320	Accuracy: [0.96326531 0.93469388 0.81836735 0.68979592 0.74897959] [0.8310]
[Epoch 1/100  ][Batch 60 ]	Loss: 1.6094	Accuracy: [0.96885246 0.94098361 0.82295082 0.70491803 0.74262295] [0.8361]
[Epoch 1/100  ][Batch 72 ]	Loss: 1.4325	Accuracy: [0.96849315 0.94794521 0.8260274  0.72191781 0.75616438] [0.8441]
[Epoch 1/100  ][Batch 84 ]	Loss: 1.1041	Accuracy: [0.96941176 0.94235294 0.83647059 0.73411765 0.75882353] [0.8482]
[Epoch 1/100  ][Batch 96 ]	Loss: 1.7570	Accuracy: [0.97319588 0.94742268 0.84536082 0.73092784 0.76701031] [0.8528]
[Epoch 1/100  ][Batch 108]	Loss: 1.9675	Accuracy: [0.97431193 0.9440367  0.84311927 0.73119266 0.

KeyboardInterrupt: 

In [7]:
# Save Model
torch.save(model, f'models/{condition}-model-dict.pt')

## Evaluation

In [9]:
model = torch.load('models/SCS-model-dict.pt').to(DEVICE)

  model = torch.load('models/SCS-model-dict.pt').to(DEVICE)


In [10]:
model.eval()
with torch.no_grad():
	for X_train, y_train in dl:
		for k, v in X_train.items():
			X_train[k] = v.to(DEVICE)
		for k, v in y_train.items():
			y_train[k] = v.to(DEVICE)
		
		y_pred = model(X_train)
		y_pred = torch.softmax(y_pred, dim=2)

		# print(f"{y_pred.size() = }")
		# print(f"{y_train['SCS'].size() = }\n")


		acc = torchmetrics.functional.accuracy(y_pred[:, :, :], y_train["SCS"][:, :, :], task='multiclass', num_classes=3)
		print(f"Metric = {acc.item()}\n")

		f1 = torchmetrics.functional.f1_score(y_pred[:, :, :], y_train["SCS"][:, :, :], task='multiclass', num_classes=3)
		print(f"Metric = {f1.item()}\n")
		

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.0

Metric = 0.0

Metric = 0.0

Metric = 0.0

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.01666666753590107

Metric = 0.01666666753590107

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.06666667014360428

Metric = 0.06666667014360428

Metric = 0.03333333507180214

Metric = 0.03333333507180214

Metric = 0.05000000074505806

Metric = 0.05000000074505806

Metric = 0.03333333507180214

Metric = 0.03333333507180214

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.008333333767950535

Metric = 0.01666666753590107

Metric = 0.01666666753590107

Metric = 0.05000000074505806

Metric = 0.05000000074505806

Metric = 0.00833333376795053

KeyboardInterrupt: 