In [2]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

In [29]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import torch
import torchvision
from torchvision import datasets
from torchvision import transforms as T # for simplifying the transforms
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, sampler, random_split
from torchvision import models

In [4]:
## Now, we import timm, torchvision image models
!pip install timm # kaggle doesnt have it installed by default
import timm
from timm.loss import LabelSmoothingCrossEntropy # This is better than normal nn.CrossEntropyLoss

Collecting timm
  Downloading timm-0.9.5-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m13.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.16.4 safetensors-0.3.2 timm-0.9.5


In [5]:
# remove warnings
import warnings
warnings.filterwarnings("ignore")

In [6]:
import matplotlib.pyplot as plt
%matplotlib inline

In [7]:
import sys
from tqdm import tqdm
import time
import copy

In [8]:
def get_classes(data_dir):
    all_data = datasets.ImageFolder(data_dir)
    return all_data.classes

In [9]:
def get_data_loaders(data_dir, batch_size, train = False):
    if train:
        #train
        transform = T.Compose([
            T.RandomHorizontalFlip(),
            T.RandomVerticalFlip(),
            T.RandomApply(torch.nn.ModuleList([T.ColorJitter()]), p=0.25),
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # imagenet means
            T.RandomErasing(p=0.2, value='random')
        ])
        train_data = datasets.ImageFolder(os.path.join(data_dir, "train/"), transform = transform)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return train_loader, len(train_data)
    else:
        # val/test
        transform = T.Compose([ # We dont need augmentation for test transforms
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # imagenet means
        ])
        val_data = datasets.ImageFolder(os.path.join(data_dir, "valid/"), transform=transform)
        test_data = datasets.ImageFolder(os.path.join(data_dir, "test/"), transform=transform)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=True, num_workers=4)
        test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=4)
        return val_loader, test_loader, len(val_data), len(test_data)

In [30]:
dataset_path = "/content/drive/MyDrive/BikeProject/Tyre Component"

In [31]:
(train_loader, train_data_len) = get_data_loaders(dataset_path, 128, train=True)
(val_loader, test_loader, valid_data_len, test_data_len) = get_data_loaders(dataset_path, 32, train=False)

In [32]:
classes = get_classes("/content/drive/MyDrive/BikeProject/Tyre Component/train/")
print(classes, len(classes))

['Good', 'bad', 'moderate'] 3


In [33]:
dataloaders = {
    "train": train_loader,
    "val": val_loader
}
dataset_sizes = {
    "train": train_data_len,
    "val": valid_data_len
}

In [17]:
print(len(train_loader), len(val_loader), len(test_loader))

5 2 2


In [18]:
# now, for the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [34]:
model = torch.hub.load('facebookresearch/deit:main', 'deit_tiny_patch16_224', pretrained=True)

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


In [35]:
for param in model.parameters(): #freeze model
    param.requires_grad = False

n_inputs = model.head.in_features
model.head = nn.Sequential(
    nn.Linear(n_inputs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(classes))
)
model = model.to(device)
print(model.head)

Sequential(
  (0): Linear(in_features=192, out_features=512, bias=True)
  (1): ReLU()
  (2): Dropout(p=0.3, inplace=False)
  (3): Linear(in_features=512, out_features=3, bias=True)
)


In [36]:
criterion = LabelSmoothingCrossEntropy()
criterion = criterion.to(device)
optimizer = optim.Adam(model.head.parameters(), lr=0.001)

In [37]:
# lr scheduler
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.97)

In [38]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=100):
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print("-"*10)

        for phase in ['train', 'val']: # We do training and validation phase per epoch
            if phase == 'train':
                model.train() # model to training mode
            else:
                model.eval() # model to evaluate

            running_loss = 0.0
            running_corrects = 0.0

            for inputs, labels in tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'): # no autograd makes validation go faster
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1) # used for accuracy
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step() # step at end of epoch

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc =  running_corrects.double() / dataset_sizes[phase]

            print("{} Loss: {:.4f} Acc: {:.4f}".format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict()) # keep the best validation accuracy model
        print()
    time_elapsed = time.time() - since # slight error
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print("Best Val Acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)
    return model

In [39]:
model_ft = train_model(model, criterion, optimizer, exp_lr_scheduler) # now it is a lot faster
# I will come back after 10 epochs

Epoch 0/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.52s/it]


train Loss: 0.9943 Acc: 0.5023


100%|██████████| 2/2 [00:04<00:00,  2.26s/it]


val Loss: 1.6749 Acc: 0.1746

Epoch 1/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.7501 Acc: 0.7317


100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


val Loss: 2.1607 Acc: 0.2222

Epoch 2/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.63s/it]


train Loss: 0.7563 Acc: 0.7083


100%|██████████| 2/2 [00:02<00:00,  1.19s/it]


val Loss: 2.2924 Acc: 0.2063

Epoch 3/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.27s/it]


train Loss: 0.6854 Acc: 0.7582


100%|██████████| 2/2 [00:02<00:00,  1.22s/it]


val Loss: 2.2833 Acc: 0.2063

Epoch 4/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.71s/it]


train Loss: 0.6571 Acc: 0.7800


100%|██████████| 2/2 [00:02<00:00,  1.10s/it]


val Loss: 2.3029 Acc: 0.2381

Epoch 5/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.27s/it]


train Loss: 0.6367 Acc: 0.8019


100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


val Loss: 2.3058 Acc: 0.2063

Epoch 6/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.28s/it]


train Loss: 0.6353 Acc: 0.8034


100%|██████████| 2/2 [00:02<00:00,  1.26s/it]


val Loss: 2.3143 Acc: 0.1905

Epoch 7/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.28s/it]


train Loss: 0.6274 Acc: 0.7972


100%|██████████| 2/2 [00:02<00:00,  1.15s/it]


val Loss: 2.2727 Acc: 0.2540

Epoch 8/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.72s/it]


train Loss: 0.6314 Acc: 0.7878


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.3061 Acc: 0.2540

Epoch 9/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.95s/it]


train Loss: 0.6067 Acc: 0.8190


100%|██████████| 2/2 [00:01<00:00,  1.48it/s]


val Loss: 2.3673 Acc: 0.1746

Epoch 10/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.56s/it]


train Loss: 0.6005 Acc: 0.8253


100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


val Loss: 2.3094 Acc: 0.2698

Epoch 11/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.90s/it]


train Loss: 0.6454 Acc: 0.8003


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.1664 Acc: 0.2698

Epoch 12/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.63s/it]


train Loss: 0.5924 Acc: 0.8268


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.0808 Acc: 0.1587

Epoch 13/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


train Loss: 0.5906 Acc: 0.8502


100%|██████████| 2/2 [00:02<00:00,  1.26s/it]


val Loss: 2.1465 Acc: 0.1746

Epoch 14/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.67s/it]


train Loss: 0.5850 Acc: 0.8424


100%|██████████| 2/2 [00:01<00:00,  1.51it/s]


val Loss: 2.1823 Acc: 0.2381

Epoch 15/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.5742 Acc: 0.8346


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2005 Acc: 0.3016

Epoch 16/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


train Loss: 0.5899 Acc: 0.8268


100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


val Loss: 2.2619 Acc: 0.2222

Epoch 17/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.42s/it]


train Loss: 0.5648 Acc: 0.8409


100%|██████████| 2/2 [00:01<00:00,  1.48it/s]


val Loss: 2.2596 Acc: 0.2222

Epoch 18/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.5626 Acc: 0.8471


100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


val Loss: 2.2535 Acc: 0.2063

Epoch 19/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.81s/it]


train Loss: 0.5376 Acc: 0.8721


100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


val Loss: 2.2286 Acc: 0.2381

Epoch 20/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.41s/it]


train Loss: 0.5529 Acc: 0.8627


100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


val Loss: 2.2026 Acc: 0.2698

Epoch 21/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


train Loss: 0.5449 Acc: 0.8736


100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


val Loss: 2.2097 Acc: 0.1905

Epoch 22/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.5378 Acc: 0.8658


100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


val Loss: 2.1700 Acc: 0.2063

Epoch 23/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.44s/it]


train Loss: 0.5380 Acc: 0.8690


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.1173 Acc: 0.2381

Epoch 24/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.84s/it]


train Loss: 0.5576 Acc: 0.8627


100%|██████████| 2/2 [00:02<00:00,  1.40s/it]


val Loss: 2.1473 Acc: 0.2698

Epoch 25/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.82s/it]


train Loss: 0.5422 Acc: 0.8768


100%|██████████| 2/2 [00:02<00:00,  1.22s/it]


val Loss: 2.1980 Acc: 0.2540

Epoch 26/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.26s/it]


train Loss: 0.5476 Acc: 0.8627


100%|██████████| 2/2 [00:02<00:00,  1.29s/it]


val Loss: 2.2408 Acc: 0.2222

Epoch 27/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.27s/it]


train Loss: 0.5364 Acc: 0.8705


100%|██████████| 2/2 [00:02<00:00,  1.22s/it]


val Loss: 2.2632 Acc: 0.2857

Epoch 28/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.36s/it]


train Loss: 0.5165 Acc: 0.8955


100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


val Loss: 2.3953 Acc: 0.1746

Epoch 29/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.43s/it]


train Loss: 0.5543 Acc: 0.8612


100%|██████████| 2/2 [00:01<00:00,  1.29it/s]


val Loss: 2.4550 Acc: 0.1746

Epoch 30/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


train Loss: 0.5126 Acc: 0.8939


100%|██████████| 2/2 [00:01<00:00,  1.45it/s]


val Loss: 2.4373 Acc: 0.2381

Epoch 31/99
----------


100%|██████████| 6/6 [00:30<00:00,  5.04s/it]


train Loss: 0.5165 Acc: 0.8924


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.4120 Acc: 0.2381

Epoch 32/99
----------


100%|██████████| 6/6 [00:32<00:00,  5.37s/it]


train Loss: 0.5077 Acc: 0.8939


100%|██████████| 2/2 [00:01<00:00,  1.34it/s]


val Loss: 2.3544 Acc: 0.2063

Epoch 33/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.64s/it]


train Loss: 0.4914 Acc: 0.9048


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2835 Acc: 0.2222

Epoch 34/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


train Loss: 0.4872 Acc: 0.9189


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2322 Acc: 0.2222

Epoch 35/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.63s/it]


train Loss: 0.4844 Acc: 0.9251


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.2256 Acc: 0.2540

Epoch 36/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.90s/it]


train Loss: 0.4720 Acc: 0.9173


100%|██████████| 2/2 [00:01<00:00,  1.47it/s]


val Loss: 2.2364 Acc: 0.2381

Epoch 37/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


train Loss: 0.4754 Acc: 0.9236


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2248 Acc: 0.2381

Epoch 38/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.63s/it]


train Loss: 0.4784 Acc: 0.9142


100%|██████████| 2/2 [00:02<00:00,  1.13s/it]


val Loss: 2.2196 Acc: 0.2540

Epoch 39/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


train Loss: 0.4726 Acc: 0.9095


100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


val Loss: 2.2453 Acc: 0.2540

Epoch 40/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.99s/it]


train Loss: 0.4974 Acc: 0.8986


100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


val Loss: 2.2731 Acc: 0.2698

Epoch 41/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.59s/it]


train Loss: 0.4817 Acc: 0.9111


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.3032 Acc: 0.2698

Epoch 42/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.90s/it]


train Loss: 0.4896 Acc: 0.9064


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.4228 Acc: 0.1905

Epoch 43/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


train Loss: 0.5124 Acc: 0.8908


100%|██████████| 2/2 [00:01<00:00,  1.36it/s]


val Loss: 2.4173 Acc: 0.2222

Epoch 44/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.68s/it]


train Loss: 0.5080 Acc: 0.8970


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.3993 Acc: 0.2222

Epoch 45/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.95s/it]


train Loss: 0.4878 Acc: 0.9111


100%|██████████| 2/2 [00:01<00:00,  1.45it/s]


val Loss: 2.4145 Acc: 0.1746

Epoch 46/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.4881 Acc: 0.9126


100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


val Loss: 2.3721 Acc: 0.2540

Epoch 47/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.85s/it]


train Loss: 0.4921 Acc: 0.9126


100%|██████████| 2/2 [00:02<00:00,  1.32s/it]


val Loss: 2.3429 Acc: 0.2540

Epoch 48/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.63s/it]


train Loss: 0.4844 Acc: 0.9080


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.4210 Acc: 0.2063

Epoch 49/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.4902 Acc: 0.9095


100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


val Loss: 2.3755 Acc: 0.2381

Epoch 50/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.47s/it]


train Loss: 0.4579 Acc: 0.9376


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.3003 Acc: 0.2540

Epoch 51/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


train Loss: 0.4652 Acc: 0.9329


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.2740 Acc: 0.2381

Epoch 52/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.66s/it]


train Loss: 0.4522 Acc: 0.9392


100%|██████████| 2/2 [00:02<00:00,  1.13s/it]


val Loss: 2.2619 Acc: 0.2381

Epoch 53/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.39s/it]


train Loss: 0.4446 Acc: 0.9438


100%|██████████| 2/2 [00:02<00:00,  1.23s/it]


val Loss: 2.2566 Acc: 0.2698

Epoch 54/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.74s/it]


train Loss: 0.4493 Acc: 0.9470


100%|██████████| 2/2 [00:01<00:00,  1.33it/s]


val Loss: 2.2739 Acc: 0.2381

Epoch 55/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.80s/it]


train Loss: 0.4476 Acc: 0.9345


100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


val Loss: 2.3135 Acc: 0.2222

Epoch 56/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.43s/it]


train Loss: 0.4329 Acc: 0.9485


100%|██████████| 2/2 [00:01<00:00,  1.28it/s]


val Loss: 2.3133 Acc: 0.2222

Epoch 57/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.46s/it]


train Loss: 0.4451 Acc: 0.9454


100%|██████████| 2/2 [00:01<00:00,  1.04it/s]


val Loss: 2.2915 Acc: 0.2698

Epoch 58/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.33s/it]


train Loss: 0.4345 Acc: 0.9501


100%|██████████| 2/2 [00:02<00:00,  1.16s/it]


val Loss: 2.2834 Acc: 0.2540

Epoch 59/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.28s/it]


train Loss: 0.4434 Acc: 0.9516


100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


val Loss: 2.3287 Acc: 0.1905

Epoch 60/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.66s/it]


train Loss: 0.4594 Acc: 0.9236


100%|██████████| 2/2 [00:02<00:00,  1.05s/it]


val Loss: 2.3063 Acc: 0.2698

Epoch 61/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.36s/it]


train Loss: 0.4486 Acc: 0.9454


100%|██████████| 2/2 [00:02<00:00,  1.23s/it]


val Loss: 2.3203 Acc: 0.2857

Epoch 62/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.30s/it]


train Loss: 0.4460 Acc: 0.9470


100%|██████████| 2/2 [00:02<00:00,  1.26s/it]


val Loss: 2.3163 Acc: 0.2857

Epoch 63/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.25s/it]


train Loss: 0.4446 Acc: 0.9392


100%|██████████| 2/2 [00:02<00:00,  1.30s/it]


val Loss: 2.3002 Acc: 0.2540

Epoch 64/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.28s/it]


train Loss: 0.4424 Acc: 0.9376


100%|██████████| 2/2 [00:02<00:00,  1.12s/it]


val Loss: 2.2967 Acc: 0.2381

Epoch 65/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.34s/it]


train Loss: 0.4390 Acc: 0.9407


100%|██████████| 2/2 [00:01<00:00,  1.20it/s]


val Loss: 2.2760 Acc: 0.2698

Epoch 66/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.87s/it]


train Loss: 0.4456 Acc: 0.9407


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2702 Acc: 0.2857

Epoch 67/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.57s/it]


train Loss: 0.4495 Acc: 0.9345


100%|██████████| 2/2 [00:01<00:00,  1.45it/s]


val Loss: 2.2661 Acc: 0.2698

Epoch 68/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


train Loss: 0.4455 Acc: 0.9392


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.2679 Acc: 0.2381

Epoch 69/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.59s/it]


train Loss: 0.4562 Acc: 0.9329


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.3131 Acc: 0.2222

Epoch 70/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


train Loss: 0.4758 Acc: 0.9080


100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


val Loss: 2.2100 Acc: 0.2540

Epoch 71/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.95s/it]


train Loss: 0.4355 Acc: 0.9516


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.1599 Acc: 0.2857

Epoch 72/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.99s/it]


train Loss: 0.4723 Acc: 0.9329


100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


val Loss: 2.1449 Acc: 0.2698

Epoch 73/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


train Loss: 0.4401 Acc: 0.9454


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.1661 Acc: 0.2540

Epoch 74/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.85s/it]


train Loss: 0.4286 Acc: 0.9532


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2536 Acc: 0.2381

Epoch 75/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.55s/it]


train Loss: 0.4344 Acc: 0.9470


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.2648 Acc: 0.2540

Epoch 76/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


train Loss: 0.4269 Acc: 0.9516


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.2469 Acc: 0.2698

Epoch 77/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.45s/it]


train Loss: 0.4272 Acc: 0.9501


100%|██████████| 2/2 [00:01<00:00,  1.47it/s]


val Loss: 2.2475 Acc: 0.2698

Epoch 78/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.92s/it]


train Loss: 0.4179 Acc: 0.9470


100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


val Loss: 2.2710 Acc: 0.2540

Epoch 79/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.96s/it]


train Loss: 0.4266 Acc: 0.9485


100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


val Loss: 2.2797 Acc: 0.2540

Epoch 80/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


train Loss: 0.4201 Acc: 0.9626


100%|██████████| 2/2 [00:01<00:00,  1.45it/s]


val Loss: 2.2737 Acc: 0.2540

Epoch 81/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.86s/it]


train Loss: 0.4199 Acc: 0.9563


100%|██████████| 2/2 [00:01<00:00,  1.46it/s]


val Loss: 2.2780 Acc: 0.2540

Epoch 82/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.50s/it]


train Loss: 0.4075 Acc: 0.9672


100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


val Loss: 2.2781 Acc: 0.2698

Epoch 83/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.49s/it]


train Loss: 0.4114 Acc: 0.9641


100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


val Loss: 2.2999 Acc: 0.2698

Epoch 84/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.91s/it]


train Loss: 0.4185 Acc: 0.9688


100%|██████████| 2/2 [00:01<00:00,  1.42it/s]


val Loss: 2.2735 Acc: 0.2698

Epoch 85/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.4212 Acc: 0.9657


100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


val Loss: 2.2508 Acc: 0.2857

Epoch 86/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.68s/it]


train Loss: 0.4029 Acc: 0.9704


100%|██████████| 2/2 [00:01<00:00,  1.22it/s]


val Loss: 2.2479 Acc: 0.2857

Epoch 87/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.87s/it]


train Loss: 0.4126 Acc: 0.9641


100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


val Loss: 2.2478 Acc: 0.2857

Epoch 88/99
----------


100%|██████████| 6/6 [00:29<00:00,  4.90s/it]


train Loss: 0.4120 Acc: 0.9548


100%|██████████| 2/2 [00:01<00:00,  1.38it/s]


val Loss: 2.2623 Acc: 0.2857

Epoch 89/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.60s/it]


train Loss: 0.4161 Acc: 0.9516


100%|██████████| 2/2 [00:02<00:00,  1.25s/it]


val Loss: 2.2774 Acc: 0.2540

Epoch 90/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.54s/it]


train Loss: 0.3980 Acc: 0.9782


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.3008 Acc: 0.2540

Epoch 91/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.51s/it]


train Loss: 0.4156 Acc: 0.9641


100%|██████████| 2/2 [00:01<00:00,  1.39it/s]


val Loss: 2.2922 Acc: 0.2540

Epoch 92/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.53s/it]


train Loss: 0.4056 Acc: 0.9672


100%|██████████| 2/2 [00:01<00:00,  1.40it/s]


val Loss: 2.2898 Acc: 0.3016

Epoch 93/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.48s/it]


train Loss: 0.4230 Acc: 0.9516


100%|██████████| 2/2 [00:01<00:00,  1.44it/s]


val Loss: 2.3003 Acc: 0.3016

Epoch 94/99
----------


100%|██████████| 6/6 [00:25<00:00,  4.32s/it]


train Loss: 0.4154 Acc: 0.9563


100%|██████████| 2/2 [00:01<00:00,  1.10it/s]


val Loss: 2.3188 Acc: 0.2857

Epoch 95/99
----------


100%|██████████| 6/6 [00:28<00:00,  4.74s/it]


train Loss: 0.4235 Acc: 0.9657


100%|██████████| 2/2 [00:02<00:00,  1.24s/it]


val Loss: 2.4178 Acc: 0.2381

Epoch 96/99
----------


100%|██████████| 6/6 [00:27<00:00,  4.58s/it]


train Loss: 0.4460 Acc: 0.9423


100%|██████████| 2/2 [00:01<00:00,  1.43it/s]


val Loss: 2.4097 Acc: 0.2381

Epoch 97/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.46s/it]


train Loss: 0.4533 Acc: 0.9360


100%|██████████| 2/2 [00:01<00:00,  1.35it/s]


val Loss: 2.4386 Acc: 0.2222

Epoch 98/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.45s/it]


train Loss: 0.4373 Acc: 0.9516


100%|██████████| 2/2 [00:01<00:00,  1.37it/s]


val Loss: 2.3695 Acc: 0.2540

Epoch 99/99
----------


100%|██████████| 6/6 [00:26<00:00,  4.42s/it]


train Loss: 0.4019 Acc: 0.9735


100%|██████████| 2/2 [00:01<00:00,  1.13it/s]

val Loss: 2.3564 Acc: 0.2540

Training complete in 48m 56s
Best Val Acc: 0.3016





In [40]:
test_loss = 0.0
class_correct = list(0 for i in range(len(classes)))
class_total = list(0 for i in range(len(classes)))
model.eval()

for data, target in tqdm(test_loader):
    data, target = data.to(device), target.to(device)
    with torch.no_grad(): # turn off autograd for faster testing
        output = model(data)
        loss = criterion(output, target)
    test_loss = loss.item() * data.size(0)
    _, pred = torch.max(output, 1)
    correct_tensor = pred.eq(target.data.view_as(pred))
    correct = np.squeeze(correct_tensor.cpu().numpy())
    if len(target) == 32:
        for i in range(32):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

test_loss = test_loss / test_data_len
print('Test Loss: {:.4f}'.format(test_loss))
for i in range(len(classes)):
    if class_total[i] > 0:
        print("Test Accuracy of %5s: %2d%% (%2d/%2d)" % (
            classes[i], 100*class_correct[i]/class_total[i], np.sum(class_correct[i]), np.sum(class_total[i])
        ))
    else:
        print("Test accuracy of %5s: NA" % (classes[i]))
print("Test Accuracy of %2d%% (%2d/%2d)" % (
            100*np.sum(class_correct)/np.sum(class_total), np.sum(class_correct), np.sum(class_total)
        ))

100%|██████████| 2/2 [00:03<00:00,  1.63s/it]

Test Loss: 0.7949
Test Accuracy of  Good:  0% ( 0/ 6)
Test Accuracy of   bad:  0% ( 0/12)
Test Accuracy of moderate: 85% (12/14)
Test Accuracy of 37% (12/32)





In [None]:
# our model earns 93% test accuracy, which is very high. lets save it
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model.cpu(), example)
traced_script_module.save("butterfly_deit_video.pt")

In [None]:
!pip install h5py


In [None]:
import h5py
import torch

# ... your previous code ...

# After testing, save the model in HDF5 format
model_filename = 'deit_model.h5'

with h5py.File(model_filename, 'w') as f:
    for name, param in model.named_parameters():
        f.create_dataset(name, data=param.cpu().numpy())

print("Model saved in HDF5 format successfully.")

# ... rest of your code ...

