In [None]:
!pip install timm
import pandas as pd

Collecting timm
  Downloading timm-0.9.12-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: timm
Successfully installed timm-0.9.12


In [None]:
# Per-class text descriptions; keep only the first 10 rows (one per class label 0-9).
descriptions = pd.read_csv('/content/Category Description.csv').head(10)

In [None]:
# Display the class-description table: one row per class, with a class name and four prompt variants.
descriptions

Unnamed: 0,Catgory,Prompt 1,Prompt 2,Prompt 3,Prompt 4
0,Airplane,An airplane is a powered flying vehicle with f...,"An airplane is a complex, powered aircraft des...","An airplane is comprised of a fuselage, wings,...",An airplane is an intricate assembly of variou...
1,Automobile,An automobile is a wheeled motor vehicle typic...,"An automobile, commonly known as a car, is a s...",An automobile typically consists of an engine ...,An automobile is a complex assembly of various...
2,Bird,"A bird is a warm-blooded, feathered vertebrate...",A bird is a type of warm-blooded vertebrate ch...,"A bird typically has feathers, wings, a beak, ...",A bird's anatomy is a unique and specialized a...
3,Cat,"A cat is a small, domesticated, carnivorous ma...","A cat, scientifically known as Felis catus, is...","A cat has fur, whiskers, retractable claws, sh...","A cat, a small carnivorous mammal, possesses s..."
4,Deer,"A deer is a hoofed, herbivorous mammal known f...","A deer is a member of the Cervidae family, a g...","A deer typically has a slender body, fur, antl...",A deer is an elegant and adaptable animal with...
5,Dog,A dog is a domesticated mammal known for its l...,"A dog, scientifically known as Canis lupus fam...","A dog typically has fur, a tail, four legs, pa...","A dog, a domesticated canine and a popular pet..."
6,Frog,"A frog is a small, tailless amphibian with lon...","A frog is a member of the amphibian class, par...","A frog has a short, moist body, long hind legs...","A frog, a small and versatile amphibian, posse..."
7,Horse,"A horse is a large, herbivorous mammal known f...","A horse is a large, domesticated mammal known ...","A horse has a large body covered in fur, a lon...","A horse, a large and powerful herbivorous mamm..."
8,Ship,A ship is a large watercraft designed for mari...,"A ship is a substantial watercraft, significan...","A ship typically consists of a hull, deck, pro...","A ship, a complex and large watercraft, compri..."
9,truck,A truck is a motorized vehicle designed primar...,A truck is a motor vehicle designed for the ef...,A truck typically consists of a cab for the dr...,A truck is a versatile motor vehicle with seve...


In [None]:
# Label index -> class name (the CSV header spells the column 'Catgory').
l2c = descriptions['Catgory'].to_dict()
l2c

{0: 'Airplane',
 1: 'Automobile',
 2: 'Bird',
 3: 'Cat',
 4: 'Deer',
 5: 'Dog',
 6: 'Frog',
 7: 'Horse',
 8: 'Ship',
 9: 'truck'}

### Offline Text Encoder

**DistilBERT tokenizer**

In [None]:
from transformers import DistilBertTokenizer

# Tokenizer for the same checkpoint the TextEncoder loads by default.
tokenizer = DistilBertTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased"
)

tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [None]:
def bert_tokenize(col, tok=None):
    """Tokenize every text in ``col`` into one padded batch of PyTorch tensors.

    Args:
        col: pandas Series (or any iterable) of strings.
        tok: tokenizer to call; defaults to the notebook-level DistilBERT ``tokenizer``.

    Returns:
        Whatever the tokenizer returns. For a Hugging Face tokenizer this is a
        ``BatchEncoding`` holding ``input_ids`` and ``attention_mask`` tensors,
        padded to the longest text.
    """
    tok = tokenizer if tok is None else tok
    sentences = col.tolist() if hasattr(col, "tolist") else list(col)
    # truncation=True caps each text at the model's maximum length, so a long
    # description cannot exceed the encoder's position embeddings.
    return tok(sentences, padding=True, truncation=True, return_tensors="pt")

In [None]:
# Tokenize the class texts once and take both tensors from the same encoding
# (row i corresponds to class label i). CLIPModel indexes these to build negative texts.
# NOTE(review): column 1 is 'Catgory' (the class name); the prompt columns
# ('Prompt 1'..'Prompt 4') start at column 2 — confirm which text is intended.
original_encoding = bert_tokenize(descriptions.iloc[:, 1])
original_text_tokens = original_encoding['input_ids']
original_attention_masks = original_encoding['attention_mask']

In [None]:
# Per-class token ids consumed by the DataLoaders (row i = class label i).
text_tokens = bert_tokenize(descriptions.iloc[:, 1]).input_ids

In [None]:
# Attention masks aligned row-for-row with the per-class token ids.
attention_masks = bert_tokenize(descriptions.iloc[:, 1]).attention_mask

### DataLoaders

In [None]:
import numpy as np

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

from torch.utils.data import Dataset, DataLoader

class Image_text_dataset(Dataset):
    """Pair each image of ``img_set`` with the tokenized text of its class.

    Args:
        img_set: indexable of ``(image, label)`` pairs (e.g. a torchvision CIFAR10 dataset).
        text_embeddings: per-class token ids, indexed by the integer label.
            Despite the name, callers pass token ids, not embeddings.
        attention_masks: per-class attention masks aligned with ``text_embeddings``.

    Each item is ``(image, label, text_tokens[label], attention_masks[label])``.
    """

    def __init__(self, img_set, text_embeddings, attention_masks):
        self.img_set = img_set
        self.text_tokens = text_embeddings
        self.attention_masks = attention_masks

    def __len__(self):
        # One item per image; the text tables only have one row per class.
        return len(self.img_set)

    def __getitem__(self, idx):
        img, label = self.img_set[idx]
        return img, label, self.text_tokens[label], self.attention_masks[label]


def get_train_valid_loader(data_dir,
                           batch_size,
                           augment=False,
                           random_seed=0,
                           valid_size=0.2,
                           shuffle=True,
                           class_text_tokens=None,
                           class_attention_masks=None):
    """Build train/validation DataLoaders from a split of the CIFAR-10 training set.

    Every sample is ``(image, label, class_tokens[label], class_mask[label])``.

    Args:
        data_dir: directory where CIFAR-10 is stored / downloaded.
        batch_size: samples per batch for both loaders.
        augment: if True, resize + center-crop images to 224x224. This is a
            deterministic upsampling, not a random augmentation, and it is
            applied to both splits.
        random_seed: seed for the index permutation used for the split.
        valid_size: fraction of the training images held out for validation, in [0, 1].
        shuffle: permute indices before splitting. Otherwise the first
            ``valid_size`` fraction becomes the validation set.
        class_text_tokens: per-class token ids indexed by label. Defaults to the
            notebook global ``text_tokens``.
        class_attention_masks: per-class attention masks. Defaults to the
            notebook global ``attention_masks``.

    Returns:
        ``(train_loader, valid_loader)``.

    Raises:
        ValueError: if ``valid_size`` is outside [0, 1].
    """
    if not 0 <= valid_size <= 1:
        raise ValueError(f"valid_size must be in [0, 1], got {valid_size}")
    tokens = text_tokens if class_text_tokens is None else class_text_tokens
    masks = attention_masks if class_attention_masks is None else class_attention_masks

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )
    resize = [transforms.Resize(224), transforms.CenterCrop(224)] if augment else []
    # Both splits come from the same images, so they share one transform;
    # validation sees exactly the preprocessing used in training.
    transform = transforms.Compose(resize + [transforms.ToTensor(), normalize])

    # One dataset object serves both splits; the samplers pick disjoint indices.
    dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=transform,
    )

    num_train = len(dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # A local legacy RandomState gives the same permutation as seeding the
        # global NumPy RNG, without mutating global random state.
        np.random.RandomState(random_seed).shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    pairs = Image_text_dataset(dataset, tokens, masks)
    train_loader = torch.utils.data.DataLoader(
        pairs, batch_size=batch_size, sampler=train_sampler
    )
    valid_loader = torch.utils.data.DataLoader(
        pairs, batch_size=batch_size, sampler=valid_sampler
    )

    return (train_loader, valid_loader)


def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    augment=False,
                    class_text_tokens=None,
                    class_attention_masks=None):
    """Build a DataLoader over the CIFAR-10 test set, paired with class texts.

    Args:
        data_dir: directory where CIFAR-10 is stored / downloaded.
        batch_size: samples per batch.
        shuffle: visit test images in random order (otherwise in index order).
        augment: if True, resize + center-crop to 224x224. This should match the
            value used for training; the training run in this notebook used True.
        class_text_tokens: per-class token ids indexed by label. Defaults to the
            notebook global ``text_tokens``.
        class_attention_masks: per-class attention masks. Defaults to the
            notebook global ``attention_masks``.

    Returns:
        A DataLoader yielding ``(image, label, tokens, attention_mask)`` batches.
    """
    tokens = text_tokens if class_text_tokens is None else class_text_tokens
    masks = attention_masks if class_attention_masks is None else class_attention_masks

    # Same statistics as get_train_valid_loader: test inputs must be normalised
    # exactly like the inputs the model was trained on.
    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )
    resize = [transforms.Resize(224), transforms.CenterCrop(224)] if augment else []
    transform = transforms.Compose(resize + [transforms.ToTensor(), normalize])

    dataset = datasets.CIFAR10(
        root=data_dir, train=False,
        download=True, transform=transform,
    )

    # Explicit sampler over all test images, so the sample count is taken
    # from the image set.
    indices = list(range(len(dataset)))
    sampler = SubsetRandomSampler(indices) if shuffle else indices

    data_loader = torch.utils.data.DataLoader(
        Image_text_dataset(dataset, tokens, masks), batch_size=batch_size, sampler=sampler
    )

    return data_loader

In [None]:
batch_size = 32

# Split of the CIFAR-10 training set (valid_size defaults to 0.2); augment=True resizes images to 224x224.
train_loader, val_loader = get_train_valid_loader(data_dir='./data', batch_size=batch_size, augment=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:10<00:00, 15835663.34it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
from torchvision.models import resnet50, ResNet50_Weights

# Standalone torchvision ResNet-50 with IMAGENET1K_V2 weights. In the visible
# notebook it is only used by the output-shape sanity check in the next cell;
# CLIPModel's image branch builds its own backbone through timm (ImageEncoder).
rn50 = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:01<00:00, 74.4MB/s]


In [None]:
# Sanity check: the ImageNet ResNet-50 maps one 3x224x224 image to 1000 logits.
# eval() + no_grad() keep the probe from updating BatchNorm running statistics
# or building an autograd graph.
x = torch.rand((3, 224, 224))
rn50.eval()
with torch.no_grad():
    out = rn50(x.unsqueeze(0))
out.shape

torch.Size([1, 1000])

### Image and Text Encoders

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import timm
from transformers import DistilBertModel, DistilBertConfig


class ImageEncoder(nn.Module):
    """
    Encode images to a fixed size vector.

    Wraps a timm backbone created with its classifier removed (``num_classes=0``)
    and global average pooling, so ``forward`` returns pooled backbone features.

    Args:
        model_name: timm model identifier.
        pretrained: whether timm loads pretrained weights.
        trainable: ``requires_grad`` value for the single unfrozen parameter
            tensor (see below). Every other backbone parameter is frozen.
    """

    def __init__(
        self, model_name='resnet50', pretrained=True, trainable=True
    ):
        super().__init__()
        self.model = timm.create_model(
            model_name, pretrained, num_classes=0, global_pool="avg"
        )
        # Build the parameter list once, not once per loop iteration.
        params = list(self.model.parameters())
        unfrozen_idx = len(params) - 2
        # Freeze everything except the second-to-last parameter tensor.
        # NOTE(review): with the classifier removed, this is presumably the weight
        # of the backbone's last normalisation layer (its bias stays frozen).
        # Confirm that this single-tensor unfreeze is the intended fine-tuning scope.
        for idx, p in enumerate(params):
            p.requires_grad = trainable if idx == unfrozen_idx else False

    def forward(self, x):
        return self.model(x)


class TextEncoder(nn.Module):
    """Encode token ids into one vector per sequence with DistilBERT.

    The sentence vector is the final hidden state at position 0, which is the
    [CLS] token when inputs come from the DistilBERT tokenizer.

    Args:
        model_name: Hugging Face checkpoint name, used only when ``pretrained``.
        pretrained: load pretrained weights. Otherwise build a randomly
            initialised model from the default ``DistilBertConfig``.
        trainable: value assigned to ``requires_grad`` on every parameter.
    """

    def __init__(self, model_name="distilbert-base-uncased", pretrained=True, trainable=True):
        super().__init__()
        self.model = (
            DistilBertModel.from_pretrained(model_name)
            if pretrained
            else DistilBertModel(config=DistilBertConfig())
        )
        self.model.requires_grad_(trainable)

        # Sequence position whose hidden state is returned as the sentence embedding.
        self.target_token_idx = 0

    def forward(self, input_ids, attention_mask):
        hidden_states = self.model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        return hidden_states[:, self.target_token_idx, :]


class ProjectionHead(nn.Module):
    """Map an encoder embedding into the shared projection space.

    Computes ``LayerNorm(Dropout(fc(GELU(p))) + p)`` where ``p = projection(x)``:
    a linear projection refined by a residual GELU -> linear branch.

    Args:
        embedding_dim: size of the incoming feature vector.
        projection_dim: size of the output vector.
        dropout: dropout probability on the residual branch.
    """

    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        # Registration order (projection, gelu, fc, dropout, layer_norm) fixes
        # parameter-init order and state_dict keys, so it is kept as is.
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        base = self.projection(x)
        refinement = self.dropout(self.fc(self.gelu(base)))
        return self.layer_norm(refinement + base)

### Text-Enhanced Image Classifier

In [None]:
!pip install info-nce-pytorch
from info_nce import InfoNCE, info_nce

Collecting info-nce-pytorch
  Downloading info_nce_pytorch-0.1.4-py3-none-any.whl (4.8 kB)
Installing collected packages: info-nce-pytorch
Successfully installed info-nce-pytorch-0.1.4


In [None]:
import random

loss_contrastive = InfoNCE(negative_mode='paired')
loss_clf = nn.CrossEntropyLoss()


class CLIPModel(nn.Module):
    """Image classifier trained jointly with an image-text contrastive objective.

    Images go through ``ImageEncoder`` and are classified by a linear head on the
    raw image features. In training mode, image and class-text features are also
    projected into a shared space and aligned with InfoNCE. The text of one
    randomly chosen *other* class per sample serves as the negative.

    Args:
        temperature: stored as ``self.temperature`` but not used. The contrastive
            loss uses the temperature configured on ``loss_contrastive``.
        image_embedding: feature size produced by the image encoder.
        text_embedding: hidden size produced by the text encoder.
        clf_loss_weight: weight of the classification loss. The contrastive loss
            gets ``1 - clf_loss_weight``.
        num_classes: number of output classes (>= 2); also the number of rows in
            the class-text table.
        class_text_tokens: optional per-class token ids (row i = class i), used to
            build negatives. If None, the notebook global ``original_text_tokens``
            is used at forward time.
        class_attention_masks: attention masks aligned with ``class_text_tokens``.
            If None, the notebook global ``original_attention_masks`` is used.
    """

    def __init__(
        self,
        temperature=1.0,
        image_embedding=2048,
        text_embedding=768,
        clf_loss_weight=0.8,
        num_classes=10,
        class_text_tokens=None,
        class_attention_masks=None,
    ):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding)
        self.temperature = temperature  # NOTE(review): unused; see class docstring.
        self.clf = nn.Linear(image_embedding, num_classes)
        self.clf_loss_weight = clf_loss_weight
        self.num_classes = num_classes
        self.class_text_tokens = class_text_tokens
        self.class_attention_masks = class_attention_masks

    def forward(self, img, label, text, att, evaluation=False):
        """Classify ``img``; in training mode also compute the combined loss.

        Args:
            img: image batch for the image encoder.
            label: integer class labels, one per image.
            text: token ids of each sample's class text.
            att: attention mask aligned with ``text``.
            evaluation: if True, return only class probabilities.

        Returns:
            ``probs`` (softmax over the classifier logits) when ``evaluation`` is
            True. Otherwise ``(probs, loss)`` with
            ``loss = (1 - clf_loss_weight) * infonce + clf_loss_weight * cross_entropy``.
        """
        image_features = self.image_encoder(img)
        clf_logits = self.clf(image_features)
        probs = F.softmax(clf_logits, dim=-1)
        # Inference only needs the classifier; skip the text branch entirely.
        if evaluation:
            return probs

        text_features = self.text_encoder(input_ids=text, attention_mask=att)
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # One negative class per sample, uniform over the other classes:
        # shifting by 1..num_classes-1 (mod num_classes) never lands on `label`.
        offsets = torch.randint(1, self.num_classes, label.shape, device=label.device)
        neg_labels = (label + offsets) % self.num_classes

        class_tokens = (original_text_tokens if self.class_text_tokens is None
                        else self.class_text_tokens)
        class_masks = (original_attention_masks if self.class_attention_masks is None
                       else self.class_attention_masks)
        neg_tokens = class_tokens[neg_labels.to(class_tokens.device)].to(text.device)
        neg_masks = class_masks[neg_labels.to(class_masks.device)].to(att.device)
        neg_embeddings = self.text_projection(self.text_encoder(neg_tokens, neg_masks))
        # 'paired' InfoNCE takes negatives shaped (batch, num_negatives, dim).
        neg_embeddings = neg_embeddings.unsqueeze(1)

        contrastive_loss = loss_contrastive(image_embeddings, text_embeddings, neg_embeddings)
        # CrossEntropyLoss applies log-softmax itself, so it must get raw logits.
        clf_loss = loss_clf(clf_logits, label)

        loss = (1 - self.clf_loss_weight) * contrastive_loss + self.clf_loss_weight * clf_loss

        return probs, loss

In [None]:
import torch.optim as optim
from tqdm import tqdm
from sklearn.metrics import classification_report
device = 'cuda' if torch.cuda.is_available() else 'cpu'

EPOCH = 10
VAL_EVERY = 100  # validate on the first batch and every VAL_EVERY batches after it
model = CLIPModel(clf_loss_weight=0.2).to(device)
optimizer = optim.AdamW(model.parameters(), lr=5e-3)


def evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy (a float in [0, 1]) of ``model`` over ``loader``.

    Divides by the number of samples actually seen, so a smaller final batch is
    counted correctly. Leaves ``model`` in eval mode; callers that keep training
    must call ``model.train()`` afterwards.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for val_imgs, val_labels, val_tokens, val_att in loader:
            val_labels = val_labels.to(device)
            probs = model(val_imgs.to(device), val_labels, val_tokens.to(device),
                          val_att.to(device), evaluation=True)
            correct += (probs.argmax(-1) == val_labels).sum().item()
            seen += val_labels.numel()
    return correct / seen if seen else 0.0


accs = []
for epoch in range(EPOCH):
    model.train()
    total_loss = 0.0
    # `batch_tokens` (not `text_tokens`) so the notebook-level class-token table is not overwritten.
    for i, (imgs, labels, batch_tokens, att) in enumerate(tqdm(train_loader)):
        images = imgs.to(device)
        texts = batch_tokens.to(device)
        labels = labels.to(device)
        att = att.to(device)

        optimizer.zero_grad()
        logits, loss = model(images, labels, texts, att)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if i % VAL_EVERY == 0:
            acc = evaluate_accuracy(model, val_loader, device)
            model.train()  # resume training mode (dropout, BatchNorm stats) after validation
            accs.append(acc)
            print(f'Epoch {epoch+1} Batch {i+1} Training Loss: {total_loss/(i+1)} Validation Accuracy: {acc}')

    print(f'Epoch {epoch+1} Avg Training Loss: {total_loss/len(train_loader)}')

  0%|          | 1/1250 [00:39<13:37:38, 39.28s/it]

Epoch 1 Batch 1 Training Loss: 1.0213509798049927 Validation Accuracy: 0.1843051016330719


  8%|▊         | 101/1250 [01:59<4:08:56, 13.00s/it]

Epoch 1 Batch 101 Training Loss: 0.9187669759929771 Validation Accuracy: 0.8442491888999939


 16%|█▌        | 201/1250 [03:19<3:44:50, 12.86s/it]

Epoch 1 Batch 201 Training Loss: 0.9006485215467007 Validation Accuracy: 0.8580271601676941


 24%|██▍       | 301/1250 [04:40<3:22:01, 12.77s/it]

Epoch 1 Batch 301 Training Loss: 0.8930887116546251 Validation Accuracy: 0.8595247268676758


 32%|███▏      | 401/1250 [06:00<3:01:26, 12.82s/it]

Epoch 1 Batch 401 Training Loss: 0.8887579879558591 Validation Accuracy: 0.8722044825553894


 40%|████      | 501/1250 [07:21<2:39:32, 12.78s/it]

Epoch 1 Batch 501 Training Loss: 0.8860938700135359 Validation Accuracy: 0.860223650932312


 48%|████▊     | 601/1250 [08:41<2:19:01, 12.85s/it]

Epoch 1 Batch 601 Training Loss: 0.8838308296862141 Validation Accuracy: 0.8772963285446167


 56%|█████▌    | 701/1250 [10:01<1:56:57, 12.78s/it]

Epoch 1 Batch 701 Training Loss: 0.8823620699951891 Validation Accuracy: 0.878494381904602


 64%|██████▍   | 801/1250 [11:22<1:35:44, 12.79s/it]

Epoch 1 Batch 801 Training Loss: 0.8813324558749777 Validation Accuracy: 0.8753993511199951


 72%|███████▏  | 901/1250 [12:42<1:14:17, 12.77s/it]

Epoch 1 Batch 901 Training Loss: 0.880417318971255 Validation Accuracy: 0.8801916837692261


 80%|████████  | 1001/1250 [14:02<53:08, 12.81s/it]

Epoch 1 Batch 1001 Training Loss: 0.8796275472426629 Validation Accuracy: 0.8807907104492188


 88%|████████▊ | 1101/1250 [15:23<31:44, 12.78s/it]

Epoch 1 Batch 1101 Training Loss: 0.8790037576768963 Validation Accuracy: 0.8844848275184631


 96%|█████████▌| 1201/1250 [16:43<10:29, 12.84s/it]

Epoch 1 Batch 1201 Training Loss: 0.8784327310983784 Validation Accuracy: 0.8852835297584534


100%|██████████| 1250/1250 [17:02<00:00,  1.22it/s]


Epoch 1 Avg Training Loss: 0.8782066305637359


  0%|          | 1/1250 [00:42<14:34:20, 42.00s/it]

Epoch 2 Batch 1 Training Loss: 0.8670360445976257 Validation Accuracy: 0.8791932463645935


  8%|▊         | 101/1250 [02:02<4:05:22, 12.81s/it]

Epoch 2 Batch 101 Training Loss: 0.8665854505973287 Validation Accuracy: 0.8885782361030579


 16%|█▌        | 201/1250 [03:22<3:43:44, 12.80s/it]

Epoch 2 Batch 201 Training Loss: 0.8671144567318817 Validation Accuracy: 0.8865814805030823


 24%|██▍       | 301/1250 [04:43<3:22:50, 12.82s/it]

Epoch 2 Batch 301 Training Loss: 0.8676710815920782 Validation Accuracy: 0.8847843408584595


 32%|███▏      | 401/1250 [06:03<3:01:02, 12.79s/it]

Epoch 2 Batch 401 Training Loss: 0.8677094907237407 Validation Accuracy: 0.8772963285446167


 40%|████      | 501/1250 [07:23<2:39:38, 12.79s/it]

Epoch 2 Batch 501 Training Loss: 0.8676693096846163 Validation Accuracy: 0.8869808316230774


 48%|████▊     | 601/1250 [08:44<2:18:47, 12.83s/it]

Epoch 2 Batch 601 Training Loss: 0.8677858445489665 Validation Accuracy: 0.888877809047699


 56%|█████▌    | 701/1250 [10:04<1:57:01, 12.79s/it]

Epoch 2 Batch 701 Training Loss: 0.8678744267805156 Validation Accuracy: 0.8886780738830566


 64%|██████▍   | 801/1250 [11:25<1:35:54, 12.82s/it]

Epoch 2 Batch 801 Training Loss: 0.8679084831409241 Validation Accuracy: 0.8881788849830627


 72%|███████▏  | 901/1250 [12:45<1:14:11, 12.75s/it]

Epoch 2 Batch 901 Training Loss: 0.8680688499345895 Validation Accuracy: 0.8916733264923096


 80%|████████  | 1001/1250 [14:05<53:21, 12.86s/it]

Epoch 2 Batch 1001 Training Loss: 0.8680776421364014 Validation Accuracy: 0.8950678706169128


 88%|████████▊ | 1101/1250 [15:26<31:41, 12.76s/it]

Epoch 2 Batch 1101 Training Loss: 0.867947556342351 Validation Accuracy: 0.8904752135276794


 96%|█████████▌| 1201/1250 [16:46<10:28, 12.84s/it]

Epoch 2 Batch 1201 Training Loss: 0.8679239719138356 Validation Accuracy: 0.8944688439369202


100%|██████████| 1250/1250 [17:05<00:00,  1.22it/s]


Epoch 2 Avg Training Loss: 0.8679589164733886


  0%|          | 1/1250 [00:42<14:36:17, 42.10s/it]

Epoch 3 Batch 1 Training Loss: 0.8513171672821045 Validation Accuracy: 0.8842851519584656


  8%|▊         | 101/1250 [02:02<4:05:01, 12.80s/it]

Epoch 3 Batch 101 Training Loss: 0.8644596173031496 Validation Accuracy: 0.8924720287322998


 16%|█▌        | 201/1250 [03:22<3:43:39, 12.79s/it]

Epoch 3 Batch 201 Training Loss: 0.8658819071095974 Validation Accuracy: 0.8926717042922974


 24%|██▍       | 301/1250 [04:42<3:21:59, 12.77s/it]

Epoch 3 Batch 301 Training Loss: 0.8655891137265683 Validation Accuracy: 0.894568681716919


 32%|███▏      | 401/1250 [06:03<3:01:28, 12.82s/it]

Epoch 3 Batch 401 Training Loss: 0.865236447190406 Validation Accuracy: 0.8952675461769104


 40%|████      | 501/1250 [07:23<2:39:20, 12.76s/it]

Epoch 3 Batch 501 Training Loss: 0.8651742725791094 Validation Accuracy: 0.8936701416969299


 48%|████▊     | 601/1250 [08:43<2:18:19, 12.79s/it]

Epoch 3 Batch 601 Training Loss: 0.8651969567709874 Validation Accuracy: 0.8924720287322998


 56%|█████▌    | 701/1250 [10:04<1:56:58, 12.78s/it]

Epoch 3 Batch 701 Training Loss: 0.8652396316875235 Validation Accuracy: 0.8968650102615356


 64%|██████▍   | 801/1250 [11:24<1:35:40, 12.78s/it]

Epoch 3 Batch 801 Training Loss: 0.8651922160617719 Validation Accuracy: 0.8966653347015381


 72%|███████▏  | 901/1250 [12:44<1:14:29, 12.81s/it]

Epoch 3 Batch 901 Training Loss: 0.8651605140620411 Validation Accuracy: 0.8967651724815369


 80%|████████  | 1001/1250 [14:05<53:01, 12.78s/it]

Epoch 3 Batch 1001 Training Loss: 0.8651390761643142 Validation Accuracy: 0.8987619876861572


 88%|████████▊ | 1101/1250 [15:25<31:52, 12.84s/it]

Epoch 3 Batch 1101 Training Loss: 0.8651117630485617 Validation Accuracy: 0.8968650102615356


 96%|█████████▌| 1201/1250 [16:45<10:27, 12.80s/it]

Epoch 3 Batch 1201 Training Loss: 0.8652255740590536 Validation Accuracy: 0.8972643613815308


100%|██████████| 1250/1250 [17:05<00:00,  1.22it/s]


Epoch 3 Avg Training Loss: 0.8652304019927979


  0%|          | 1/1250 [00:42<14:35:51, 42.08s/it]

Epoch 4 Batch 1 Training Loss: 0.8615022897720337 Validation Accuracy: 0.8938698172569275


  8%|▊         | 101/1250 [02:02<4:04:11, 12.75s/it]

Epoch 4 Batch 101 Training Loss: 0.8638060547337674 Validation Accuracy: 0.8972643613815308


 16%|█▌        | 201/1250 [03:22<3:43:37, 12.79s/it]

Epoch 4 Batch 201 Training Loss: 0.8632082639642022 Validation Accuracy: 0.8964656591415405


 24%|██▍       | 301/1250 [04:42<3:21:57, 12.77s/it]

Epoch 4 Batch 301 Training Loss: 0.8631889687424086 Validation Accuracy: 0.8958665728569031


 32%|███▏      | 401/1250 [06:02<3:00:57, 12.79s/it]

Epoch 4 Batch 401 Training Loss: 0.8635744538687709 Validation Accuracy: 0.8961661458015442


 40%|████      | 501/1250 [07:22<2:39:17, 12.76s/it]

Epoch 4 Batch 501 Training Loss: 0.8637627672530458 Validation Accuracy: 0.8969648480415344


 48%|████▊     | 601/1250 [08:43<2:18:17, 12.79s/it]

Epoch 4 Batch 601 Training Loss: 0.8636905527154538 Validation Accuracy: 0.8979632258415222


 56%|█████▌    | 701/1250 [10:03<1:56:55, 12.78s/it]

Epoch 4 Batch 701 Training Loss: 0.8635141812105491 Validation Accuracy: 0.8991613388061523


 64%|██████▍   | 801/1250 [11:23<1:35:45, 12.80s/it]

Epoch 4 Batch 801 Training Loss: 0.863527133521367 Validation Accuracy: 0.8943690061569214


 72%|███████▏  | 901/1250 [12:44<1:14:13, 12.76s/it]

Epoch 4 Batch 901 Training Loss: 0.8636213595277065 Validation Accuracy: 0.8979632258415222


 80%|████████  | 1001/1250 [14:04<53:12, 12.82s/it]

Epoch 4 Batch 1001 Training Loss: 0.8635376983589226 Validation Accuracy: 0.8978633880615234


 88%|████████▊ | 1101/1250 [15:24<31:40, 12.75s/it]

Epoch 4 Batch 1101 Training Loss: 0.8634505014761701 Validation Accuracy: 0.8983625769615173


 96%|█████████▌| 1201/1250 [16:44<10:27, 12.81s/it]

Epoch 4 Batch 1201 Training Loss: 0.8634899902304046 Validation Accuracy: 0.8984624147415161


100%|██████████| 1250/1250 [17:04<00:00,  1.22it/s]


Epoch 4 Avg Training Loss: 0.8634932363033294


  0%|          | 1/1250 [00:41<14:32:37, 41.92s/it]

Epoch 5 Batch 1 Training Loss: 0.8579273223876953 Validation Accuracy: 0.899760365486145


  8%|▊         | 101/1250 [02:02<4:04:33, 12.77s/it]

Epoch 5 Batch 101 Training Loss: 0.8602998575361649 Validation Accuracy: 0.8979632258415222


 16%|█▌        | 201/1250 [03:22<3:43:20, 12.77s/it]

Epoch 5 Batch 201 Training Loss: 0.8600475983833199 Validation Accuracy: 0.8977635502815247


 24%|██▍       | 301/1250 [04:42<3:22:10, 12.78s/it]

Epoch 5 Batch 301 Training Loss: 0.8603018908405621 Validation Accuracy: 0.8999600410461426


 32%|███▏      | 401/1250 [06:02<3:00:49, 12.78s/it]

Epoch 5 Batch 401 Training Loss: 0.8609618653085761 Validation Accuracy: 0.899760365486145


 40%|████      | 501/1250 [07:23<2:39:37, 12.79s/it]

Epoch 5 Batch 501 Training Loss: 0.8612115404563035 Validation Accuracy: 0.8977635502815247


 48%|████▊     | 601/1250 [08:43<2:18:05, 12.77s/it]

Epoch 5 Batch 601 Training Loss: 0.8614473262563125 Validation Accuracy: 0.8982627391815186


 56%|█████▌    | 701/1250 [10:03<1:57:02, 12.79s/it]

Epoch 5 Batch 701 Training Loss: 0.8615091079992167 Validation Accuracy: 0.8926717042922974


 64%|██████▍   | 801/1250 [11:23<1:35:31, 12.77s/it]

Epoch 5 Batch 801 Training Loss: 0.8617752650555005 Validation Accuracy: 0.8978633880615234


 72%|███████▏  | 901/1250 [12:43<1:14:22, 12.79s/it]

Epoch 5 Batch 901 Training Loss: 0.8619542481102769 Validation Accuracy: 0.901457667350769


 80%|████████  | 1001/1250 [14:04<53:03, 12.79s/it]

Epoch 5 Batch 1001 Training Loss: 0.8619793604661178 Validation Accuracy: 0.8994608521461487


 88%|████████▊ | 1101/1250 [15:24<31:42, 12.77s/it]

Epoch 5 Batch 1101 Training Loss: 0.8620000744710501 Validation Accuracy: 0.9013578295707703


 96%|█████████▌| 1201/1250 [16:44<10:26, 12.78s/it]

Epoch 5 Batch 1201 Training Loss: 0.8621205289993159 Validation Accuracy: 0.9001597166061401


100%|██████████| 1250/1250 [17:03<00:00,  1.22it/s]


Epoch 5 Avg Training Loss: 0.8621575971603394


  0%|          | 1/1250 [00:41<14:33:51, 41.98s/it]

Epoch 6 Batch 1 Training Loss: 0.8613770008087158 Validation Accuracy: 0.8976637125015259


  8%|▊         | 101/1250 [02:02<4:04:21, 12.76s/it]

Epoch 6 Batch 101 Training Loss: 0.8600472229542119 Validation Accuracy: 0.9012579917907715


 16%|█▌        | 201/1250 [03:22<3:43:11, 12.77s/it]

Epoch 6 Batch 201 Training Loss: 0.8602247508011054 Validation Accuracy: 0.8995606899261475


 24%|██▍       | 301/1250 [04:42<3:21:56, 12.77s/it]

Epoch 6 Batch 301 Training Loss: 0.8603951002276221 Validation Accuracy: 0.8960662484169006


 32%|███▏      | 401/1250 [06:02<3:00:26, 12.75s/it]

Epoch 6 Batch 401 Training Loss: 0.8605177112648315 Validation Accuracy: 0.8969648480415344


 40%|████      | 501/1250 [07:22<2:39:46, 12.80s/it]

Epoch 6 Batch 501 Training Loss: 0.8607881351145442 Validation Accuracy: 0.900658905506134


 48%|████▊     | 601/1250 [08:42<2:17:49, 12.74s/it]

Epoch 6 Batch 601 Training Loss: 0.8608232724686431 Validation Accuracy: 0.9018570184707642


 56%|█████▌    | 701/1250 [10:03<1:57:14, 12.81s/it]

Epoch 6 Batch 701 Training Loss: 0.8608729993225674 Validation Accuracy: 0.8994608521461487


 64%|██████▍   | 801/1250 [11:23<1:35:20, 12.74s/it]

Epoch 6 Batch 801 Training Loss: 0.860983180121685 Validation Accuracy: 0.8998602032661438


 72%|███████▏  | 901/1250 [12:43<1:14:28, 12.80s/it]

Epoch 6 Batch 901 Training Loss: 0.8611473919416506 Validation Accuracy: 0.900658905506134


 80%|████████  | 1001/1250 [14:03<52:56, 12.76s/it]

Epoch 6 Batch 1001 Training Loss: 0.8612604714654661 Validation Accuracy: 0.8990615010261536


 88%|████████▊ | 1101/1250 [15:24<31:50, 12.82s/it]

Epoch 6 Batch 1101 Training Loss: 0.8612682816118246 Validation Accuracy: 0.8999600410461426


 92%|█████████▏| 1151/1250 [15:43<01:21,  1.22it/s]


KeyboardInterrupt: ignored

In [None]:
# Persist only the weights. The path presumably requires Google Drive to be mounted (Colab).
torch.save(obj=model.state_dict(), f='/content/drive/MyDrive/cv_model.pth')