In [39]:
# hold to your courage hold to your wits
# Imports
import torch

import torch.nn as nn
import torch
import torch.optim as optim
from torchvision import transforms
from tqdm import tqdm
#from torchsummary import summary
import math

import torch.nn.functional as F

In [40]:
#device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [41]:
img_channel=3
img_height=32
img_width=100
num_class=128+1
map_to_seq_hidden=64
rnn_hidden=128 + 15

In [42]:
class ECABlock(nn.Module):
    def __init__(self, channels, k_size=3):
        super(ECABlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1d = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)  
        y = y.squeeze(-1).transpose(-1, -2)  # (bs, c, 1) -> (bs, 1, c)
        y = self.conv1d(y)  
        y = self.sigmoid(y)  
        y = y.transpose(-1, -2).unsqueeze(-1)  #  (bs, 1, c) - (bs, c, 1, 1)
        y = y.expand_as(x)  
        return x * y  # why this?


In [43]:
# Model

class CRNN_2(nn.Module):

    def __init__(self, img_channel=3, img_height=28, img_width=128, num_class=128+1,
                 map_to_seq_hidden=64, rnn_hidden=128, leaky_relu=False):
        super(CRNN_2, self).__init__()
        

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True )
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True,num_layers=2,dropout=0.5) #,num_layers=2,dropout=0.4

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        channels = [img_channel, 64, 128, 256, 256,350, 400 ,400]
        kernel_sizes = [3, 3, 3, 3, 3, 3,2]
        strides = [1, 1, 1,1, 1,1 ,1]
        paddings = [1, 1, 1, 1 ,1,1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i+1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)
            if output_channel > 256:
                cnn.add_module(f'attention{i}',ECABlock(output_channel))

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('dropout0',nn.Dropout2d(p=0.2)) # Dropout
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        
        cnn.add_module('dropout1',nn.Dropout2d(p=0.2)) # Dropout
        conv_relu(3)
        cnn.add_module('dropout2',nn.Dropout2d(p=0.3))
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        cnn.add_module('dropout2',nn.Dropout2d(p=0.4))
        conv_relu(5, batch_norm=True)
        cnn.add_module('dropout2',nn.Dropout2d(p=0.4))
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16-1 , img_width // 4 -1
        return cnn, (output_channel, output_height, output_width)

    def forward(self, images):
        # shape of images: (batch, channel, height, width)

        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)
        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)

In [44]:
# load pretrained model if available
#model_2 = torch.load('model_2.pth')

In [45]:
model_2 = CRNN_2( img_channel, img_height, img_width, num_class,map_to_seq_hidden, rnn_hidden).to(device)

In [46]:
# transform
transform= transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32,128)),
    transforms.RandomRotation(degrees=20),
])

In [47]:
# Hyperparams
batch_size =512
lr = 0.0003
epochs = 200
criterion = nn.CTCLoss()
optimiser = optim.Adam(model_2.parameters(),lr=lr, weight_decay=1e-5)
#optimiser = torch.optim.AdamW(model_2.parameters(), lr=0.0003)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimiser, 'min', patience=10, factor=0.1)


In [48]:
#data loders
import os
from PIL import Image
from torch.utils.data import Dataset,random_split
# from matplotlib import pyplot as plt

class CustomImageTextDataset(Dataset):
    def __init__(self, image_dir, image_names, texts, transform=None):
        self.image_dir = image_dir
        self.image_names = image_names
        self.texts = texts
        self.transform = transform

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_path =    self.image_names[idx]
        image = Image.open(img_path).convert('RGB')
        text = self.texts[idx]

        if self.transform:
            image = self.transform(image)

        return image, text

##
class mydataset:
    def __init__(self,data_dir,transform=None,val=False):
        #paths to the directories/files
        extraction_path = data_dir
        image_dir = os.path.join(extraction_path, 'images')
        text_file_path = os.path.join(extraction_path, 'train_labels.txt' if val==False else 'val_labels.txt')
        image_names_file_path = os.path.join(extraction_path, 'train_images.txt' if val==False else 'val_images.txt')

        #read the image names
        with open(image_names_file_path, 'r',encoding='utf-8') as f:
            image_names = f.read().splitlines()

        #read the texts
        with open(text_file_path, 'r',encoding='utf-8') as f:
            texts = f.read().splitlines()

        # ensure that image_names and texts have the same length
        assert len(image_names) == len(texts), "Mismatch between image names and texts"
        # retrurn object of you clss
        self.dataset= CustomImageTextDataset(image_dir=image_dir, image_names=image_names, texts=texts, transform=transform)

    def get_dataset(self):
        return self.dataset

# # dataset
from torch.utils.data import DataLoader


#dataset->train
data_train = mydataset(data_dir='train',transform=transform)
data_train = data_train.get_dataset()
data_train,_ = random_split(data_train,[len(data_train)//1,len(data_train)-len(data_train)//1])

#dataset->val
data_val = mydataset(data_dir='val',transform=transform,val=True)
data_val = data_val.get_dataset()

 #dataloaders
trainloader = DataLoader(data_train,shuffle=True,batch_size=batch_size)
valloader = DataLoader(data_val,shuffle=True,batch_size=batch_size)

In [49]:
print(valloader.__len__())
trainloader.__len__()

5


35

In [50]:
# Label Transformer
class LabelTransformer():
    """
    encoder and decoder

    Args:
        letters (str): Letters contained in the data
    """

    def __init__(self, letters= "ಂಃಅಆಇಈಉಊಋಌಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೄೆೇೈೊೋೌ್ೕೖೞೠೡೢೣ೦೧೨೩೪೫೬೭೮೯ೱೲ"):
        self.encode_map = {letter: idx+1 for idx, letter in enumerate(letters)}
        self.decode_map = ' ' + letters

    # def encode(self, text):
    #     if isinstance(text, str):
    #         length = [len(text)]
    #         result = [self.encode_map[letter] for letter in text]
    #     else:
    #         length = []
    #         result = []
    #         for word in text:
    #             length.append(len(word))
    #             result.extend([self.encode_map[letter] for letter in word])
    #     return torch.IntTensor(result), torch.IntTensor(length)
    def encode(self, text):
        result = []
        length = []
        for word in text:
            length.append(len(word))
            for letter in word:
                if letter in self.encode_map:
                    result.append(self.encode_map[letter])
                else:
                    # Log the problematic letter and the word it was part of
                    # print(f"Warning: Character '{letter}' not found in encode_map. In word: '{word}'")
                    # Optionally, handle the unknown character
                    result.append(self.encode_map.get('unknown', 0))  # Use a default value or handle as needed
        return torch.IntTensor(result), torch.IntTensor(length)

    def decode(self, text_code, length):
        result = []
        idx = 0
        for len in length:
            word = []
            for i in range(len):
                if text_code[idx] != 0 and (i == 0 or text_code[idx] != text_code[idx - 1]):
                    word.append(self.decode_map[text_code[idx]])
                idx += 1
            result.append(''.join(word))
        return result


In [51]:
# CTC decoder
def ctc_decoder(predictions, label_transformer, blank=0):
    decoded_output = []
    batch_size = predictions.size(1)
    max_seq_length = predictions.size(0)

    # Get the indices of the max probabilities (predictions)
    _, max_indices = torch.max(predictions, 2)

    for batch in range(batch_size):
        pred_string = []
        previous_char = None
        for t in range(max_seq_length):
            current_char = max_indices[t][batch].item()
            if current_char != blank and current_char != previous_char:
                pred_string.append(current_char)
            previous_char = current_char
        decoded_output.append(''.join([label_transformer.decode_map[idx] for idx in pred_string if idx != blank]))

    return decoded_output

In [52]:
# val acc
def val_acc(model,loader,labeltransform):
  model.eval()
  with torch.no_grad():
    total_correct = 0
    total_samples = 0
    for data, target in tqdm(loader):
        data = data.to(device)

        pred = model(data)  # Model output shape: (seq_len, batch_size, num_classes)

        # Decode predictions
        pred_decoded = ctc_decoder(pred, labeltransform)

        # Compute the number of correct predictions
        correct_predictions = sum(p == t for p, t in zip(pred_decoded, target))
        total_correct += correct_predictions
        total_samples += len(target)

    # Calculate accuracy
    acc = total_correct *100/ total_samples
    print(f'Validation Accuracy: {acc:.4f}')
    return acc


In [53]:
# # Training Loop goes here
# loss=[]
# epoch_loss=0
# labeltransform = LabelTransformer()

# for epoch in range(epochs):
#     model_2.train()
#     print("============== epoch :",epoch,"=============")
#     for idx,(data,targets) in tqdm(enumerate(trainloader)):
#         data = data.to(device)
#         x = targets
#         # Output
#         targets,target_length = labeltransform.encode(targets)
#         optimiser.zero_grad()
#         # out
#         out = model_2(data)
#         out_length = torch.IntTensor([out.size(0)]*out.size(1))

#         # Loss
#         loss = criterion(F.log_softmax(out,2),targets,out_length,target_length)
#         loss.backward()
#         # Gradient Clipping
#         #torch.nn.utils.clip_grad_norm_(model_2.parameters(), max_norm=1)
#         # Optim
#         optimiser.step()

#         epoch_loss += loss.item()
#     # After epoch loss calculation
#     scheduler.step(epoch_loss / len(trainloader))
#     print("epoch :",epoch,"loss",epoch_loss/len(trainloader))

#     if(epoch % 5 == 0):
#       acc = val_acc(model_2,valloader,labeltransform)
#       val_acc(model_2,trainloader,labeltransform)
#     if(acc > 75):
#         #save the trained model
#         torch.save(model_2.state_dict(), 'model_75+.pth')
#         break
#     loss.append(epoch_loss)
#     epoch_loss=0



# Training Loop goes here
epoch_losses=[]
epoch_loss=0
labeltransform = LabelTransformer()

for epoch in range(epochs):
    model_2.train()
    print("============== epoch :",epoch,"=============")
    for idx,(data,targets) in tqdm(enumerate(trainloader)):
        data = data.to(device)
        x = targets
        # Output
        targets,target_length = labeltransform.encode(targets)
        optimiser.zero_grad()
        # out
        out = model_2(data)
        out_length = torch.IntTensor([out.size(0)]*out.size(1))

        # Loss
        loss_value = criterion(F.log_softmax(out,2),targets,out_length,target_length)
        loss_value.backward()
        # Gradient Clipping
        #torch.nn.utils.clip_grad_norm_(model_2.parameters(), max_norm=1)
        # Optim
        optimiser.step()

        epoch_loss += loss_value.item()
    # After epoch loss calculation
    scheduler.step(epoch_loss / len(trainloader))
    print("epoch :",epoch,"loss",epoch_loss/len(trainloader))

    if(epoch % 5 == 0):
      acc = val_acc(model_2,valloader,labeltransform)
      val_acc(model_2,trainloader,labeltransform)
    if(acc > 75):
        #save the trained model
        torch.save(model_2.state_dict(), 'model_75+.pth')
        break
    epoch_losses.append(epoch_loss)
    epoch_loss=0



35it [03:52,  6.66s/it]


epoch : 0 loss 12.018446241106306


100%|██████████| 5/5 [00:30<00:00,  6.07s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [00:48<00:00,  1.38s/it]


Validation Accuracy: 0.0114


35it [04:25,  7.60s/it]


epoch : 1 loss 3.9659824848175047


35it [05:30,  9.45s/it]


epoch : 2 loss 3.8376657962799072


35it [06:38, 11.39s/it]


epoch : 3 loss 3.7994773728506908


35it [06:26, 11.04s/it]


epoch : 4 loss 3.763892800467355


35it [06:38, 11.40s/it]


epoch : 5 loss 3.7342359951564243


100%|██████████| 5/5 [01:06<00:00, 13.38s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:41<00:00,  8.06s/it]


Validation Accuracy: 0.0114


35it [06:51, 11.76s/it]


epoch : 6 loss 3.715666811806815


35it [06:46, 11.63s/it]


epoch : 7 loss 3.6965186732155937


35it [06:41, 11.48s/it]


epoch : 8 loss 3.6986708777291435


35it [06:17, 10.78s/it]


epoch : 9 loss 3.6800170489719934


35it [06:04, 10.41s/it]


epoch : 10 loss 3.665705108642578


100%|██████████| 5/5 [00:36<00:00,  7.25s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:17<00:00,  7.35s/it]


Validation Accuracy: 0.0114


35it [06:05, 10.43s/it]


epoch : 11 loss 3.657414143426078


35it [06:18, 10.82s/it]


epoch : 12 loss 3.810915429251535


35it [05:52, 10.06s/it]


epoch : 13 loss 3.666113267626081


35it [05:38,  9.67s/it]


epoch : 14 loss 3.6460169655936103


35it [06:51, 11.77s/it]


epoch : 15 loss 3.809569617680141


100%|██████████| 5/5 [00:36<00:00,  7.37s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:18<00:00,  7.39s/it]


Validation Accuracy: 0.0114


35it [06:56, 11.91s/it]


epoch : 16 loss 3.686087077004569


35it [06:32, 11.22s/it]


epoch : 17 loss 3.64206668308803


35it [06:33, 11.23s/it]


epoch : 18 loss 3.6398325647626604


35it [06:17, 10.77s/it]


epoch : 19 loss 3.634634576525007


35it [06:42, 11.50s/it]


epoch : 20 loss 3.629819849559239


100%|██████████| 5/5 [00:35<00:00,  7.13s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:16<00:00,  7.31s/it]


Validation Accuracy: 0.0114


35it [06:08, 10.52s/it]


epoch : 21 loss 3.629171528135027


35it [06:12, 10.65s/it]


epoch : 22 loss 3.6262703554970876


35it [05:34,  9.55s/it]


epoch : 23 loss 3.6216349601745605


35it [05:29,  9.42s/it]


epoch : 24 loss 3.6215459414890834


35it [05:29,  9.40s/it]


epoch : 25 loss 3.62020708492824


100%|██████████| 5/5 [00:28<00:00,  5.73s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:20<00:00,  5.72s/it]


Validation Accuracy: 0.0114


35it [05:29,  9.41s/it]


epoch : 26 loss 3.614537763595581


35it [05:29,  9.41s/it]


epoch : 27 loss 3.6113827432904926


35it [05:29,  9.41s/it]


epoch : 28 loss 3.6064945016588483


35it [05:29,  9.41s/it]


epoch : 29 loss 3.5996040139879497


35it [05:29,  9.40s/it]


epoch : 30 loss 3.5800095285688127


100%|██████████| 5/5 [00:28<00:00,  5.72s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:20<00:00,  5.72s/it]


Validation Accuracy: 0.0114


35it [05:28,  9.39s/it]


epoch : 31 loss 3.5719607489449636


35it [05:28,  9.39s/it]


epoch : 32 loss 3.56241329738072


35it [05:32,  9.51s/it]


epoch : 33 loss 3.5459111077444896


35it [05:29,  9.41s/it]


epoch : 34 loss 3.5383300713130406


35it [05:30,  9.43s/it]


epoch : 35 loss 3.615359122412545


100%|██████████| 5/5 [00:35<00:00,  7.14s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:33<00:00,  6.11s/it]


Validation Accuracy: 0.0114


35it [05:52, 10.07s/it]


epoch : 36 loss 3.6047399452754427


35it [05:47,  9.94s/it]


epoch : 37 loss 3.554686198915754


35it [05:57, 10.22s/it]


epoch : 38 loss 3.536082567487444


35it [05:32,  9.51s/it]


epoch : 39 loss 3.5325555528913224


35it [05:32,  9.49s/it]


epoch : 40 loss 3.6265627929142545


100%|██████████| 5/5 [00:29<00:00,  5.86s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:22<00:00,  5.79s/it]


Validation Accuracy: 0.0114


35it [05:32,  9.50s/it]


epoch : 41 loss 3.68743143762861


35it [05:32,  9.49s/it]


epoch : 42 loss 3.5983219282967704


35it [05:32,  9.49s/it]


epoch : 43 loss 3.586940424782889


35it [05:32,  9.50s/it]


epoch : 44 loss 3.9081858158111573


35it [05:32,  9.49s/it]


epoch : 45 loss 3.615515763419015


100%|██████████| 5/5 [00:28<00:00,  5.79s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:22<00:00,  5.79s/it]


Validation Accuracy: 0.0114


35it [05:32,  9.50s/it]


epoch : 46 loss 3.5905037266867503


35it [05:32,  9.50s/it]


epoch : 47 loss 3.688518639973232


35it [05:32,  9.49s/it]


epoch : 48 loss 3.595443493979318


35it [05:32,  9.49s/it]


epoch : 49 loss 3.5865071501050676


35it [05:50, 10.01s/it]


epoch : 50 loss 3.5805677345820834


100%|██████████| 5/5 [00:35<00:00,  7.07s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:08<00:00,  7.11s/it]


Validation Accuracy: 0.0057


35it [06:20, 10.87s/it]


epoch : 51 loss 3.579515436717442


35it [06:08, 10.54s/it]


epoch : 52 loss 3.58171592439924


35it [06:02, 10.36s/it]


epoch : 53 loss 3.5797931262425013


35it [06:21, 10.89s/it]


epoch : 54 loss 3.6123058523450577


35it [06:15, 10.72s/it]


epoch : 55 loss 3.674945422581264


100%|██████████| 5/5 [00:39<00:00,  7.99s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:32<00:00,  7.78s/it]


Validation Accuracy: 0.0114


35it [06:07, 10.51s/it]


epoch : 56 loss 3.890699277605329


35it [06:04, 10.41s/it]


epoch : 57 loss 3.940535511289324


35it [06:02, 10.36s/it]


epoch : 58 loss 3.8961835656847272


35it [06:00, 10.30s/it]


epoch : 59 loss 3.8507874965667725


35it [06:01, 10.32s/it]


epoch : 60 loss 3.8056130273001534


100%|██████████| 5/5 [00:38<00:00,  7.65s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:16<00:00,  7.32s/it]


Validation Accuracy: 0.0114


35it [06:00, 10.30s/it]


epoch : 61 loss 3.7623586927141464


35it [06:09, 10.56s/it]


epoch : 62 loss 3.7382443700517927


35it [07:36, 13.05s/it]


epoch : 63 loss 3.7375751835959297


35it [07:25, 12.74s/it]


epoch : 64 loss 3.732443346296038


35it [05:58, 10.24s/it]


epoch : 65 loss 3.7291099957057408


100%|██████████| 5/5 [00:38<00:00,  7.65s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:15<00:00,  7.30s/it]


Validation Accuracy: 0.0114


35it [05:59, 10.26s/it]


epoch : 66 loss 3.723175539289202


35it [05:58, 10.24s/it]


epoch : 67 loss 3.7237564836229597


35it [05:57, 10.22s/it]


epoch : 68 loss 3.7199020862579344


35it [06:00, 10.29s/it]


epoch : 69 loss 3.7271373203822544


35it [06:04, 10.43s/it]


epoch : 70 loss 3.7204137529645647


100%|██████████| 5/5 [00:29<00:00,  5.82s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [03:24<00:00,  5.83s/it]


Validation Accuracy: 0.0114


35it [05:37,  9.64s/it]


epoch : 71 loss 3.7219226905277796


35it [05:34,  9.57s/it]


epoch : 72 loss 3.7201955795288084


35it [05:34,  9.57s/it]


epoch : 73 loss 3.7165046623774938


35it [05:47,  9.92s/it]


epoch : 74 loss 3.7171919209616524


35it [05:49,  9.99s/it]


epoch : 75 loss 3.718807772227696


100%|██████████| 5/5 [00:39<00:00,  7.83s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [05:38<00:00,  9.66s/it]


Validation Accuracy: 0.0114


35it [05:32,  9.51s/it]


epoch : 76 loss 3.7191260474068777


35it [05:32,  9.49s/it]


epoch : 77 loss 3.717685890197754


35it [05:45,  9.89s/it]


epoch : 78 loss 3.718906000682286


35it [06:31, 11.20s/it]


epoch : 79 loss 3.71404606955392


35it [06:23, 10.97s/it]


epoch : 80 loss 3.717430019378662


100%|██████████| 5/5 [00:49<00:00,  9.89s/it]


Validation Accuracy: 0.0000


100%|██████████| 35/35 [04:22<00:00,  7.50s/it]


Validation Accuracy: 0.0114


35it [06:15, 10.72s/it]


epoch : 81 loss 3.7159846782684327


35it [12:59, 22.27s/it] 


epoch : 82 loss 3.718151208332607


28it [37:29, 16.07s/it] 

In [16]:
#save the trained model
torch.save(model_2.state_dict(), 'model_76.1.pth')

In [17]:
print(val_acc(model_2,valloader,labeltransform))
val_acc(model_2,trainloader,labeltransform)

  0%|          | 0/5 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: 'val/images/15122022_L_GH042772_image_000103_23.png'