In [1]:
import os
import cv2
import numpy as np
import math
import mediapipe as mp
from matplotlib import pyplot as plt
import glob
from util.img2bone import HandDetector
import os
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import os
import glob
import numpy as np
from tqdm.auto import tqdm
from loader.dataloader import SkeletonData


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def find_adjacency_matrix():
    mp_hands = mp.solutions.hands
    hands = mp_hands.Hands()
    adj = torch.zeros((21,21))
    for connection in mp_hands.HAND_CONNECTIONS:
        adj[connection[0],connection[1]] = 1
        adj[connection[1],connection[0]] = 1
    return adj 

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class GraphConvolution(nn.Module):
	def __init__(self, input_dim, output_dim, num_vetex, act=F.relu, dropout=0.5, bias=True):
		super(GraphConvolution, self).__init__()

		self.alpha = 1.

		self.act = act
		self.dropout = nn.Dropout(dropout)
		self.weight = nn.Parameter(torch.randn(input_dim, output_dim)).to(device)
		if bias:
			self.bias = nn.Parameter(torch.randn(output_dim)).to(device)
		else:
			self.bias = None

		for w in [self.weight]:
			nn.init.xavier_normal_(w)

	def normalize(self, m):
		rowsum = torch.sum(m, 0)
		r_inv = torch.pow(rowsum, -0.5)
		r_mat_inv = torch.diag(r_inv).float()

		m_norm = torch.mm(r_mat_inv, m)
		m_norm = torch.mm(m_norm, r_mat_inv)

		return m_norm

	def forward(self, adj, x):

		x = self.dropout(x)

		# K-ordered Chebyshev polynomial
		adj_norm = self.normalize(adj)
		sqr_norm = self.normalize(torch.mm(adj,adj))
		m_norm = (self.alpha*adj_norm + (1.-self.alpha)*sqr_norm).to(device)

		x_tmp = torch.einsum('abcd,de->abce', x, self.weight)
		x_out = torch.einsum('ij,abid->abjd', m_norm, x_tmp)
		if self.bias is not None:
			x_out += self.bias
		x_out = self.act(x_out)
  
		
		return x_out
		
		

class StandConvolution(nn.Module):
	def __init__(self, dims, num_classes, dropout):
		super(StandConvolution, self).__init__()

		self.dropout = nn.Dropout(dropout)
		self.conv = nn.Sequential(
								   nn.Conv2d(dims[0], dims[1], kernel_size=5, stride=2),
								   nn.InstanceNorm2d(dims[1]),
								   nn.ReLU(inplace=True),
								   #nn.AvgPool2d(3, stride=2),
								   nn.Conv2d(dims[1], dims[2], kernel_size=5, stride=2),
								   nn.InstanceNorm2d(dims[2]),
								   nn.ReLU(inplace=True),
								   #nn.AvgPool2d(3, stride=2),
								   nn.Conv2d(dims[2], dims[3], kernel_size=5, stride=2),
								   nn.InstanceNorm2d(dims[3]),
								   nn.ReLU(inplace=True),
								   #nn.AvgPool2d(3, stride=2)
								   ).to(device)

		self.fc = nn.Linear(dims[3]*5*5, num_classes).to(device)

	def forward(self, x):
		x = self.dropout(x.permute(0,3,1,2))# (1,9,62,63)
		x_tmp = self.conv(x)
		
		x_out = self.fc(x_tmp.view(x.size(0), -1))

		return x_out


In [4]:
class GGCN(nn.Module):
	def __init__(self, adj, num_classes, gc_dims, sc_dims, dropout=0.5):
		super(GGCN, self).__init__()

		
		adj = adj + torch.eye(adj.size(0)).to(adj).detach()
		ident = torch.eye(adj.size(0)).to(adj)
		zeros = torch.zeros(adj.size(0), adj.size(1)).to(adj)
		self.adj = torch.cat([torch.cat([adj, ident, zeros], 1),
							  torch.cat([ident, adj, ident], 1),
							  torch.cat([zeros, ident, adj], 1)], 0).float()
	
		self.gcl = GraphConvolution(gc_dims[0], gc_dims[1], 21, dropout=dropout)
		self.conv= StandConvolution(sc_dims, num_classes, dropout=dropout)

	def forward(self, x):
		# x: 
		x = torch.cat([x[:,:-2],x[:,1:-1],x[:,2:]],dim = 2)
		multi_conv = self.gcl(self.adj, x) 
		logit = self.conv(multi_conv) # (1,62,63,9)
		return logit

In [5]:
def train(train_loader,model,criterion,optimizer,device):
    running_loss = 0
    model.train()
   
    for images,labels in tqdm(train_loader):
        
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        outputs = model(images)
        loss = criterion(outputs,labels)
        running_loss += loss.item()
        
        #backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        
    
    epoch_loss = running_loss / (len(train_loader))
    return model,epoch_loss,optimizer

def validate(valid_loader,model,criterion,device):
    model.eval()
    running_loss = 0
   
    
    for images,labels in tqdm(valid_loader):
       
        images = images.to(device)
        labels = labels.to(device)
        
        # forward
        outputs = model(images)
        loss = criterion(outputs,labels)
        running_loss += loss.item()

    epoch_loss = running_loss / (len(valid_loader))
    return model,epoch_loss

def get_accuracy(model,data_loader,device):
    correct = 0
    total = 0
    
    with torch.no_grad():
        model.eval()
        for images, labels in data_loader:
            images = images.to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            predicted = torch.argmax(torch.softmax(outputs,1),1)
            total += labels.shape[0]
            correct += (predicted == labels).sum().item()
    return correct*100/total

def plot_losses(train_losses,valid_losses):
    train_losses = np.array(train_losses)
    valid_losses = np.array(valid_losses)
    
    fig, ax1 = plt.subplots(1, 1)
    ax1.plot(train_losses, color="blue", label="train_loss")
    ax1.plot(valid_losses, color="red", label="valid_loss")
    ax1.set(title="Loss over epochs",
            xlabel="Epoch",
            ylabel="Loss")
    ax1.legend()
    
def plot_accuracy(train_acc,valid_acc):
    train_acc = np.array(train_acc)
    valid_acc = np.array(valid_acc)
    
    fig, ax1 = plt.subplots(1, 1)
    ax1.plot(train_acc, color="blue", label="train_acc")
    ax1.plot(valid_acc, color="red", label="val_acc")
    ax1.set(title="Accuracy over epochs",
            xlabel="Epoch",
            ylabel="Accuracy")
    ax1.legend()

In [10]:
train_set = SkeletonData("data/108_new/train.pkl")
val_set = SkeletonData("data/108_new/val.pkl")
test_set = SkeletonData("data/108_new/test.pkl")

train_loader = DataLoader(train_set,batch_size = 128,drop_last = False)
valid_loader = DataLoader(val_set,batch_size = 512,drop_last = False)
test_loader = DataLoader(test_set,batch_size = 512,drop_last = False)

tensor([33, 10, 10, 36, 38, 39, 35])
tensor([40, 33, 10, 37, 36, 38, 39, 35])
tensor([33, 10, 36, 35])


In [12]:
model = GGCN(find_adjacency_matrix(),41,[3,9],[9, 16, 32, 64],0.5).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
from util.evaluation import early_stopping
device = 'cuda:0'
epochs = 50
train_losses = []
valid_losses = []
train_accuracy = []
val_accuracy = []
for epoch in range(epochs):
    # training
    model,train_loss,optimizer = train(train_loader,model,criterion,optimizer,device)
    
    # validation
    with torch.no_grad():
        model, valid_loss = validate(valid_loader, model, criterion, device)
    train_acc = get_accuracy(model,train_loader,device)
    val_acc = get_accuracy(model,valid_loader,device)
    print("Epoch {} --- Train loss = {} --- Valid loss = {} -- Train set accuracy = {} % Valid set Accuracy = {} %".format
          (epoch+1,train_loss,valid_loss,train_acc,val_acc))
    # save loss value
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    
    #save accuracy
    train_accuracy.append(train_acc)
    val_accuracy.append(val_acc)
    
    # best_val_accuracy,should_stop = early_stopping(val_accuracy,10)
    
    # if should_stop:
    #     break
    
    # if best_val_accuracy == val_accuracy[-1]:
    #     torch.save(model.state_dict(),"best_model.pth")
    #     print("Save best model ","Best_accuracy = ",get_accuracy(model,test_loader,device))
       



100%|██████████| 1/1 [00:00<00:00, 135.17it/s]
100%|██████████| 1/1 [00:00<00:00, 296.40it/s]


Epoch 1 --- Train loss = 3.576970338821411 --- Valid loss = 3.2583107948303223 -- Train set accuracy = 14.285714285714286 % Valid set Accuracy = 0.0 %


100%|██████████| 1/1 [00:00<00:00, 141.06it/s]
100%|██████████| 1/1 [00:00<00:00, 309.18it/s]


Epoch 2 --- Train loss = 3.0643270015716553 --- Valid loss = 2.978947401046753 -- Train set accuracy = 14.285714285714286 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 146.88it/s]
100%|██████████| 1/1 [00:00<00:00, 263.86it/s]


Epoch 3 --- Train loss = 2.774655818939209 --- Valid loss = 2.847449541091919 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 149.42it/s]
100%|██████████| 1/1 [00:00<00:00, 304.73it/s]


Epoch 4 --- Train loss = 2.009899616241455 --- Valid loss = 2.860274314880371 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 161.29it/s]
100%|██████████| 1/1 [00:00<00:00, 306.53it/s]


Epoch 5 --- Train loss = 1.8919811248779297 --- Valid loss = 2.930762529373169 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 153.11it/s]
100%|██████████| 1/1 [00:00<00:00, 310.62it/s]


Epoch 6 --- Train loss = 2.1113016605377197 --- Valid loss = 3.0272912979125977 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 159.85it/s]
100%|██████████| 1/1 [00:00<00:00, 325.37it/s]


Epoch 7 --- Train loss = 2.0386312007904053 --- Valid loss = 3.1241579055786133 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 161.23it/s]
100%|██████████| 1/1 [00:00<00:00, 323.29it/s]


Epoch 8 --- Train loss = 1.9557768106460571 --- Valid loss = 3.215517044067383 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 169.37it/s]
100%|██████████| 1/1 [00:00<00:00, 324.03it/s]


Epoch 9 --- Train loss = 1.701180100440979 --- Valid loss = 3.307643413543701 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 171.53it/s]
100%|██████████| 1/1 [00:00<00:00, 344.47it/s]


Epoch 10 --- Train loss = 1.8123027086257935 --- Valid loss = 3.3915822505950928 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 174.28it/s]
100%|██████████| 1/1 [00:00<00:00, 374.22it/s]


Epoch 11 --- Train loss = 1.8712809085845947 --- Valid loss = 3.4642128944396973 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 172.17it/s]
100%|██████████| 1/1 [00:00<00:00, 353.09it/s]


Epoch 12 --- Train loss = 1.82351815700531 --- Valid loss = 3.5323829650878906 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 155.87it/s]
100%|██████████| 1/1 [00:00<00:00, 357.14it/s]


Epoch 13 --- Train loss = 1.7804195880889893 --- Valid loss = 3.5988943576812744 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 169.86it/s]
100%|██████████| 1/1 [00:00<00:00, 399.61it/s]


Epoch 14 --- Train loss = 1.9176712036132812 --- Valid loss = 3.657748222351074 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 170.83it/s]
100%|██████████| 1/1 [00:00<00:00, 330.91it/s]


Epoch 15 --- Train loss = 1.703122854232788 --- Valid loss = 3.7127792835235596 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 134.95it/s]
100%|██████████| 1/1 [00:00<00:00, 392.43it/s]


Epoch 16 --- Train loss = 1.718017816543579 --- Valid loss = 3.7625486850738525 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 182.35it/s]
100%|██████████| 1/1 [00:00<00:00, 430.19it/s]


Epoch 17 --- Train loss = 1.8659145832061768 --- Valid loss = 3.810631275177002 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 194.04it/s]
100%|██████████| 1/1 [00:00<00:00, 356.26it/s]


Epoch 18 --- Train loss = 1.8979483842849731 --- Valid loss = 3.8557424545288086 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 170.81it/s]
100%|██████████| 1/1 [00:00<00:00, 371.80it/s]


Epoch 19 --- Train loss = 1.9527925252914429 --- Valid loss = 3.8902528285980225 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 162.02it/s]
100%|██████████| 1/1 [00:00<00:00, 416.60it/s]


Epoch 20 --- Train loss = 1.760746955871582 --- Valid loss = 3.9259767532348633 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 156.60it/s]
100%|██████████| 1/1 [00:00<00:00, 389.73it/s]


Epoch 21 --- Train loss = 1.8987417221069336 --- Valid loss = 3.95367169380188 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 163.90it/s]
100%|██████████| 1/1 [00:00<00:00, 359.29it/s]


Epoch 22 --- Train loss = 1.907375454902649 --- Valid loss = 3.9809134006500244 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 166.71it/s]
100%|██████████| 1/1 [00:00<00:00, 377.90it/s]


Epoch 23 --- Train loss = 1.7742984294891357 --- Valid loss = 4.001198768615723 -- Train set accuracy = 0.0 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 163.40it/s]
100%|██████████| 1/1 [00:00<00:00, 408.96it/s]


Epoch 24 --- Train loss = 2.019768714904785 --- Valid loss = 4.022531032562256 -- Train set accuracy = 14.285714285714286 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 145.93it/s]
100%|██████████| 1/1 [00:00<00:00, 385.36it/s]


Epoch 25 --- Train loss = 1.6780508756637573 --- Valid loss = 4.042186737060547 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 173.91it/s]
100%|██████████| 1/1 [00:00<00:00, 392.36it/s]


Epoch 26 --- Train loss = 1.6978187561035156 --- Valid loss = 4.0597004890441895 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 163.19it/s]
100%|██████████| 1/1 [00:00<00:00, 416.72it/s]


Epoch 27 --- Train loss = 1.8347431421279907 --- Valid loss = 4.073700904846191 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 146.41it/s]
100%|██████████| 1/1 [00:00<00:00, 386.11it/s]


Epoch 28 --- Train loss = 1.6689021587371826 --- Valid loss = 4.086935520172119 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 116.59it/s]
100%|██████████| 1/1 [00:00<00:00, 255.67it/s]


Epoch 29 --- Train loss = 2.099936008453369 --- Valid loss = 4.098018169403076 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 115.77it/s]
100%|██████████| 1/1 [00:00<00:00, 252.47it/s]


Epoch 30 --- Train loss = 1.6169215440750122 --- Valid loss = 4.10741662979126 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 120.17it/s]
100%|██████████| 1/1 [00:00<00:00, 251.41it/s]


Epoch 31 --- Train loss = 1.806649923324585 --- Valid loss = 4.117376804351807 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 123.16it/s]
100%|██████████| 1/1 [00:00<00:00, 257.67it/s]


Epoch 32 --- Train loss = 1.7790597677230835 --- Valid loss = 4.128647804260254 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 124.80it/s]
100%|██████████| 1/1 [00:00<00:00, 265.68it/s]


Epoch 33 --- Train loss = 1.6515657901763916 --- Valid loss = 4.134856224060059 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 120.25it/s]
100%|██████████| 1/1 [00:00<00:00, 255.13it/s]


Epoch 34 --- Train loss = 1.7432504892349243 --- Valid loss = 4.139345169067383 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 125.92it/s]
100%|██████████| 1/1 [00:00<00:00, 252.49it/s]


Epoch 35 --- Train loss = 1.827993631362915 --- Valid loss = 4.144358158111572 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 117.30it/s]
100%|██████████| 1/1 [00:00<00:00, 264.93it/s]


Epoch 36 --- Train loss = 1.818280577659607 --- Valid loss = 4.148002624511719 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 120.70it/s]
100%|██████████| 1/1 [00:00<00:00, 263.89it/s]


Epoch 37 --- Train loss = 1.6674412488937378 --- Valid loss = 4.154806137084961 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 122.78it/s]
100%|██████████| 1/1 [00:00<00:00, 268.20it/s]


Epoch 38 --- Train loss = 1.7804715633392334 --- Valid loss = 4.158304691314697 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 122.61it/s]
100%|██████████| 1/1 [00:00<00:00, 217.86it/s]


Epoch 39 --- Train loss = 1.6534415483474731 --- Valid loss = 4.162496566772461 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 107.24it/s]
100%|██████████| 1/1 [00:00<00:00, 223.39it/s]


Epoch 40 --- Train loss = 2.005502223968506 --- Valid loss = 4.1672587394714355 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 107.59it/s]
100%|██████████| 1/1 [00:00<00:00, 226.72it/s]


Epoch 41 --- Train loss = 1.8229204416275024 --- Valid loss = 4.1765618324279785 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 106.64it/s]
100%|██████████| 1/1 [00:00<00:00, 224.34it/s]


Epoch 42 --- Train loss = 1.8808776140213013 --- Valid loss = 4.185878753662109 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 105.66it/s]
100%|██████████| 1/1 [00:00<00:00, 231.35it/s]


Epoch 43 --- Train loss = 1.829727053642273 --- Valid loss = 4.194519996643066 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 105.45it/s]
100%|██████████| 1/1 [00:00<00:00, 221.20it/s]


Epoch 44 --- Train loss = 2.0091941356658936 --- Valid loss = 4.200605392456055 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 103.85it/s]
100%|██████████| 1/1 [00:00<00:00, 223.80it/s]


Epoch 45 --- Train loss = 1.647979497909546 --- Valid loss = 4.20478630065918 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 105.28it/s]
100%|██████████| 1/1 [00:00<00:00, 227.06it/s]


Epoch 46 --- Train loss = 1.8476203680038452 --- Valid loss = 4.206307888031006 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 129.02it/s]
100%|██████████| 1/1 [00:00<00:00, 253.97it/s]


Epoch 47 --- Train loss = 1.9624783992767334 --- Valid loss = 4.210923194885254 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 164.26it/s]
100%|██████████| 1/1 [00:00<00:00, 332.91it/s]


Epoch 48 --- Train loss = 2.054805278778076 --- Valid loss = 4.2110595703125 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 167.00it/s]
100%|██████████| 1/1 [00:00<00:00, 338.11it/s]


Epoch 49 --- Train loss = 1.9694898128509521 --- Valid loss = 4.213414192199707 -- Train set accuracy = 28.571428571428573 % Valid set Accuracy = 12.5 %


100%|██████████| 1/1 [00:00<00:00, 171.67it/s]
100%|██████████| 1/1 [00:00<00:00, 349.55it/s]

Epoch 50 --- Train loss = 1.8734970092773438 --- Valid loss = 4.216571807861328 -- Train set accuracy = 42.857142857142854 % Valid set Accuracy = 12.5 %



