In [1]:
import torch.nn as nn
import torch
from torch import Tensor
from typing import Type
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import torch.optim as optim
import numpy as np
import random
import math
import torch.nn.functional as F
from aircraft import Aircraft

In [2]:
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
np.random.seed(seed)
random.seed(seed)

In [3]:
class BasicBlock(nn.Module):
    def __init__(
        self, 
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expansion: int = 1,
        downsample: nn.Module = None
    ) -> None:
        super(BasicBlock, self).__init__()
        self.expansion = expansion
        self.downsample = downsample
        self.conv1 = nn.Conv2d(
            in_channels, 
            out_channels, 
            kernel_size=3, 
            stride=stride, 
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, 
            out_channels*self.expansion, 
            kernel_size=3, 
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return  out


In [4]:
class ResNet(nn.Module):
    def __init__(
        self, 
        img_channels: int,
        num_layers: int,
        block: Type[BasicBlock]
    ) -> None:
        super(ResNet, self).__init__()
        if num_layers == 18:
            layers = [2, 2, 2, 2]
            self.expansion = 1
        
        self.in_channels = 64
        self.conv1 = nn.Conv2d(
            in_channels=img_channels,
            out_channels=self.in_channels,
            kernel_size=7, 
            stride=2,
            padding=3,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.final_downsample =  nn.Sequential(
                nn.Conv2d(
                    512, 
                    128,
                    kernel_size=1,
                    stride=1,
                    bias=False 
                ),
                nn.BatchNorm2d(128)
            )
        # self.fc = nn.Linear(512, 128)

    def _make_layer(
        self, 
        block: Type[BasicBlock],
        out_channels: int,
        blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        downsample = None
        if stride != 1:
            """
            This should pass from `layer2` to `layer4` or 
            when building ResNets50 and above. Section 3.3 of the paper
            Deep Residual Learning for Image Recognition
            (https://arxiv.org/pdf/1512.03385v1.pdf).
            """
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels, 
                    out_channels*self.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False 
                ),
                nn.BatchNorm2d(out_channels * self.expansion),
            )
        layers = []
        layers.append(
            block(
                self.in_channels, out_channels, stride, self.expansion, downsample
            )
        )
        self.in_channels = out_channels * self.expansion

        for i in range(1, blocks):
            layers.append(block(
                self.in_channels,
                out_channels,
                expansion=self.expansion
            ))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.final_downsample(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return x

In [5]:
class KernelNN(nn.Module):
    def __init__(self):
        super(KernelNN, self).__init__()
        self.fc1 = nn.Linear(256, 128)
        self.bn = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.bn(x)
        x = F.relu(self.fc2(x))
        x = self.bn(x)
        x = torch.sigmoid(self.fc3(x))
        return x

In [6]:
class Net_NW_Head(nn.Module):

    def __init__(self, classes, device):
        self.device = device
        super(Net_NW_Head, self).__init__()
        self.rs18 = ResNet(img_channels=3, num_layers=18, block=BasicBlock)
        self.kernel_nn = KernelNN()
        self.classes = classes
        self.eps = 1e-5
        
    def set_support(self, support_element_imgs, support_labels):
        self.support_element_imgs = support_element_imgs
        self.support_labels_one_hot = torch.tensor(np.eye(self.classes)[support_labels], requires_grad=True).to(device)

    def apply_kernel(self, x):
        x_and_sup_concatenated = []
        for x_i in x:
            for s_i in self.support_elements:
                x_and_sup_concatenated.append(torch.cat([x_i, s_i]).unsqueeze(0))
        x_and_sup_concatenated = torch.cat(x_and_sup_concatenated)

        raw_distances = self.kernel_nn(x_and_sup_concatenated)
        raw_distances = torch.reshape(raw_distances, (len(x), len(self.support_elements)))

        raw_distances = raw_distances + self.eps
      
        distances = raw_distances / raw_distances.sum(dim=-1).unsqueeze(-1)
        result = []
        for i in range(distances.shape[0]):
            result.append((distances[i].unsqueeze(1) * self.support_labels_one_hot).sum(dim = 0).unsqueeze(0))
        result = torch.cat(result)
        return result

    def forward(self, x):

        x = self.rs18(x)
        self.support_elements = self.rs18(self.support_element_imgs)
        x = self.apply_kernel(x)
        x = x + self.eps
        x = torch.log(x)
        return x


In [7]:
plt.style.use('ggplot')
IMG_SIZE = 224

def get_data(batch_size=64):

    dataset_train = Aircraft('./data',
        train=True,
        download=True,
        transform=transforms.Compose([
              transforms.Resize((IMG_SIZE,IMG_SIZE)),
              transforms.RandomHorizontalFlip(p=0.5),
              transforms.RandomVerticalFlip(p=0.3),
              transforms.RandomRotation(degrees=(-15, 15)),
              transforms.ToTensor()
           ])
    )
    dataset_valid = Aircraft('./data',
         train=False,
         download=True,
         transform=transforms.Compose([
              transforms.Resize((IMG_SIZE,IMG_SIZE)),
              transforms.ToTensor()
           ])
    )
    train_loader = DataLoader(
        dataset_train, 
        batch_size=batch_size,
        shuffle=True
    )
    valid_loader = DataLoader(
        dataset_valid, 
        batch_size=batch_size,
        shuffle=False
    )
    return train_loader, valid_loader, dataset_train, dataset_valid
  

In [8]:
def save_plots(train_acc, valid_acc, train_loss, valid_loss, name=None):
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_acc, color='tab:blue', linestyle='-', 
        label='train accuracy'
    )
    plt.plot(
        valid_acc, color='tab:red', linestyle='-', 
        label='validataion accuracy'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.savefig(os.path.join('outputs', name+'_accuracy.png'))
    plt.figure(figsize=(10, 7))
    plt.plot(
        train_loss, color='tab:blue', linestyle='-', 
        label='train loss'
    )
    plt.plot(
        valid_loss, color='tab:red', linestyle='-', 
        label='validataion loss'
    )
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(os.path.join('outputs', name+'_loss.png'))

In [9]:
SUPPORT_SIZE = 3

def train(model, trainloader, optimizer, criterion, device, full_support_set):
    model.train()
    print('Training')
    train_running_loss = 0.0
    train_running_correct = 0
    counter = 0
    for i, data in tqdm(enumerate(trainloader), total=len(trainloader)):

        support_elements_temp = []
        support_labels = []
        for label in full_support_set.keys():
            support_elements_temp.extend(random.sample(full_support_set[label], SUPPORT_SIZE))
            support_labels.extend([label for i in range(SUPPORT_SIZE)])

        img_shape = support_elements_temp[0].shape
        support_elements_ng = torch.zeros((len(support_elements_temp), img_shape[0], img_shape[1], img_shape[2]))
        for i in range(len(support_elements_temp)):
          support_elements_ng[i] = support_elements_temp[i]

        support_elements = support_elements_ng.clone().detach().requires_grad_(True)
        support_elements = support_elements.to(device)

        counter += 1
        image, labels = data
        image = image.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()

        # support
        model.set_support(support_elements, support_labels)

        outputs = model(image)

        loss = criterion(outputs, labels)
        train_running_loss += loss.item()
        _, preds = torch.max(outputs.data, 1)
        train_running_correct += (preds == labels).sum().item()
        loss.backward()
        optimizer.step()
    
    epoch_loss = train_running_loss / counter
    epoch_acc = 100. * (train_running_correct / len(trainloader.dataset))
    return epoch_loss, epoch_acc

def validate(model, testloader, criterion, device, full_support_set):
    model.eval()
    print('Validation')
    valid_running_loss = 0.0
    valid_running_correct = 0
    counter = 0

    with torch.no_grad():
        for i, data in tqdm(enumerate(testloader), total=len(testloader)):

            support_elements_temp = []
            support_labels = []
            for label in full_support_set.keys():
                support_elements_temp.extend(random.sample(full_support_set[label], SUPPORT_SIZE))
                support_labels.extend([label for i in range(SUPPORT_SIZE)])

            img_shape = support_elements_temp[0].shape
            support_elements = torch.zeros((len(support_elements_temp), img_shape[0], img_shape[1], img_shape[2]))
            for i in range(len(support_elements_temp)):
                support_elements[i] = support_elements_temp[i]

            support_elements = support_elements.to(device)

            counter += 1
            image, labels = data
            image = image.to(device)
            labels = labels.to(device)

            # support
            model.set_support(support_elements, support_labels)
            outputs = model(image)
            loss = criterion(outputs, labels)
            valid_running_loss += loss.item()
            _, preds = torch.max(outputs.data, 1)
            valid_running_correct += (preds == labels).sum().item()
        
    epoch_loss = valid_running_loss / counter
    epoch_acc = 100. * (valid_running_correct / len(testloader.dataset))
    return epoch_loss, epoch_acc

In [10]:


epochs = 350
batch_size = 16
learning_rate = 0.001
momentum = 0.9
classes = 100

device = torch.device("cuda:4" if torch.cuda.is_available() else 'cpu')


In [11]:
train_loader, valid_loader, dataset_train, dataset_valid = get_data(batch_size=batch_size)

In [12]:
len(dataset_valid)

3333

In [13]:
full_support_set = dict()
for data, i in dataset_train:
    full_support_set.setdefault(i, []).append(data)

In [14]:
device

device(type='cuda', index=4)

In [15]:
model = Net_NW_Head(classes, device).to(device)
plot_name = 'nw_scratch'

In [16]:
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
criterion = F.nll_loss

In [17]:
train_loss, valid_loss = [], []
train_acc, valid_acc = [], []

for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(
        model, 
        train_loader, 
        optimizer, 
        criterion,
        device,
        full_support_set
    )
    valid_epoch_loss, valid_epoch_acc = validate(
        model, 
        valid_loader, 
        criterion,
        device,
        full_support_set
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(learning_rate)
    torch.save(model.state_dict(), 'test_nw_fc_FGVCA2_w')
            
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-'*50)

[INFO]: Epoch 1 of 350
Training


100%|██████████| 417/417 [08:13<00:00,  1.18s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 4.533, training acc: 2.220
Validation loss: 4.500, validation acc: 2.670
--------------------------------------------------
[INFO]: Epoch 2 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 4.425, training acc: 3.060
Validation loss: 4.468, validation acc: 2.490
--------------------------------------------------
[INFO]: Epoch 3 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 4.285, training acc: 3.885
Validation loss: 4.274, validation acc: 3.510
--------------------------------------------------
[INFO]: Epoch 4 of 350
Training


100%|██████████| 417/417 [08:27<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 4.151, training acc: 4.590
Validation loss: 4.215, validation acc: 4.170
--------------------------------------------------
[INFO]: Epoch 5 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 4.054, training acc: 5.175
Validation loss: 4.494, validation acc: 1.740
--------------------------------------------------
[INFO]: Epoch 6 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.65it/s]


0.001
Training loss: 3.999, training acc: 5.250
Validation loss: 4.135, validation acc: 5.251
--------------------------------------------------
[INFO]: Epoch 7 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 3.906, training acc: 5.805
Validation loss: 4.068, validation acc: 6.541
--------------------------------------------------
[INFO]: Epoch 8 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 3.832, training acc: 7.500
Validation loss: 3.881, validation acc: 6.151
--------------------------------------------------
[INFO]: Epoch 9 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.62it/s]


0.001
Training loss: 3.709, training acc: 8.550
Validation loss: 4.045, validation acc: 6.331
--------------------------------------------------
[INFO]: Epoch 10 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 3.654, training acc: 8.295
Validation loss: 3.673, validation acc: 8.461
--------------------------------------------------
[INFO]: Epoch 11 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 3.558, training acc: 9.315
Validation loss: 3.441, validation acc: 11.521
--------------------------------------------------
[INFO]: Epoch 12 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 3.511, training acc: 10.964
Validation loss: 3.420, validation acc: 12.841
--------------------------------------------------
[INFO]: Epoch 13 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 3.428, training acc: 11.639
Validation loss: 3.307, validation acc: 14.461
--------------------------------------------------
[INFO]: Epoch 14 of 350
Training


100%|██████████| 417/417 [08:33<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 3.344, training acc: 13.049
Validation loss: 3.356, validation acc: 14.491
--------------------------------------------------
[INFO]: Epoch 15 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 3.295, training acc: 13.799
Validation loss: 3.186, validation acc: 15.632
--------------------------------------------------
[INFO]: Epoch 16 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 3.208, training acc: 15.314
Validation loss: 3.273, validation acc: 15.482
--------------------------------------------------
[INFO]: Epoch 17 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 3.093, training acc: 17.219
Validation loss: 3.047, validation acc: 19.832
--------------------------------------------------
[INFO]: Epoch 18 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 3.006, training acc: 18.689
Validation loss: 2.988, validation acc: 22.892
--------------------------------------------------
[INFO]: Epoch 19 of 350
Training


100%|██████████| 417/417 [08:33<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 2.941, training acc: 19.169
Validation loss: 2.814, validation acc: 21.272
--------------------------------------------------
[INFO]: Epoch 20 of 350
Training


100%|██████████| 417/417 [08:27<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 2.828, training acc: 21.989
Validation loss: 2.969, validation acc: 20.162
--------------------------------------------------
[INFO]: Epoch 21 of 350
Training


100%|██████████| 417/417 [08:33<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 2.772, training acc: 22.439
Validation loss: 2.684, validation acc: 26.403
--------------------------------------------------
[INFO]: Epoch 22 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 2.705, training acc: 23.759
Validation loss: 2.762, validation acc: 25.083
--------------------------------------------------
[INFO]: Epoch 23 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 2.645, training acc: 25.364
Validation loss: 2.579, validation acc: 26.013
--------------------------------------------------
[INFO]: Epoch 24 of 350
Training


100%|██████████| 417/417 [08:35<00:00,  1.24s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 2.584, training acc: 26.729
Validation loss: 2.734, validation acc: 25.743
--------------------------------------------------
[INFO]: Epoch 25 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 2.529, training acc: 28.319
Validation loss: 2.672, validation acc: 27.453
--------------------------------------------------
[INFO]: Epoch 26 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:12<00:00,  1.58it/s]


0.001
Training loss: 2.461, training acc: 29.759
Validation loss: 2.483, validation acc: 31.143
--------------------------------------------------
[INFO]: Epoch 27 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 2.419, training acc: 30.688
Validation loss: 2.429, validation acc: 32.793
--------------------------------------------------
[INFO]: Epoch 28 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.63it/s]


0.001
Training loss: 2.369, training acc: 32.683
Validation loss: 2.318, validation acc: 35.344
--------------------------------------------------
[INFO]: Epoch 29 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 2.279, training acc: 34.348
Validation loss: 2.308, validation acc: 34.533
--------------------------------------------------
[INFO]: Epoch 30 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 2.227, training acc: 35.353
Validation loss: 2.344, validation acc: 35.074
--------------------------------------------------
[INFO]: Epoch 31 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 2.171, training acc: 36.883
Validation loss: 2.244, validation acc: 35.224
--------------------------------------------------
[INFO]: Epoch 32 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 2.149, training acc: 37.288
Validation loss: 2.153, validation acc: 39.064
--------------------------------------------------
[INFO]: Epoch 33 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 2.082, training acc: 39.403
Validation loss: 2.087, validation acc: 41.344
--------------------------------------------------
[INFO]: Epoch 34 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.21s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 2.023, training acc: 40.843
Validation loss: 2.181, validation acc: 38.524
--------------------------------------------------
[INFO]: Epoch 35 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 1.978, training acc: 41.818
Validation loss: 2.142, validation acc: 38.284
--------------------------------------------------
[INFO]: Epoch 36 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 1.942, training acc: 42.208
Validation loss: 2.188, validation acc: 39.064
--------------------------------------------------
[INFO]: Epoch 37 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 1.875, training acc: 43.678
Validation loss: 2.045, validation acc: 41.584
--------------------------------------------------
[INFO]: Epoch 38 of 350
Training


100%|██████████| 417/417 [08:35<00:00,  1.24s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 1.848, training acc: 44.608
Validation loss: 2.012, validation acc: 43.144
--------------------------------------------------
[INFO]: Epoch 39 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 1.780, training acc: 46.708
Validation loss: 2.300, validation acc: 39.094
--------------------------------------------------
[INFO]: Epoch 40 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 1.750, training acc: 47.773
Validation loss: 1.833, validation acc: 48.275
--------------------------------------------------
[INFO]: Epoch 41 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 1.710, training acc: 48.193
Validation loss: 1.862, validation acc: 46.235
--------------------------------------------------
[INFO]: Epoch 42 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 1.695, training acc: 50.217
Validation loss: 1.800, validation acc: 46.865
--------------------------------------------------
[INFO]: Epoch 43 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 1.612, training acc: 51.147
Validation loss: 1.852, validation acc: 49.115
--------------------------------------------------
[INFO]: Epoch 44 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.62it/s]


0.001
Training loss: 1.619, training acc: 50.907
Validation loss: 2.019, validation acc: 44.314
--------------------------------------------------
[INFO]: Epoch 45 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 1.568, training acc: 51.927
Validation loss: 1.853, validation acc: 48.065
--------------------------------------------------
[INFO]: Epoch 46 of 350
Training


100%|██████████| 417/417 [08:33<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 1.530, training acc: 53.457
Validation loss: 1.752, validation acc: 52.625
--------------------------------------------------
[INFO]: Epoch 47 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 1.510, training acc: 54.057
Validation loss: 1.721, validation acc: 52.595
--------------------------------------------------
[INFO]: Epoch 48 of 350
Training


100%|██████████| 417/417 [08:32<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 1.447, training acc: 56.052
Validation loss: 1.703, validation acc: 52.265
--------------------------------------------------
[INFO]: Epoch 49 of 350
Training


100%|██████████| 417/417 [08:36<00:00,  1.24s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.63it/s]


0.001
Training loss: 1.430, training acc: 57.027
Validation loss: 1.795, validation acc: 49.205
--------------------------------------------------
[INFO]: Epoch 50 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 1.408, training acc: 56.682
Validation loss: 1.761, validation acc: 50.705
--------------------------------------------------
[INFO]: Epoch 51 of 350
Training


100%|██████████| 417/417 [08:27<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 1.354, training acc: 58.017
Validation loss: 1.688, validation acc: 53.405
--------------------------------------------------
[INFO]: Epoch 52 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 1.350, training acc: 58.347
Validation loss: 1.651, validation acc: 54.035
--------------------------------------------------
[INFO]: Epoch 53 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.63it/s]


0.001
Training loss: 1.310, training acc: 59.532
Validation loss: 1.626, validation acc: 54.305
--------------------------------------------------
[INFO]: Epoch 54 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 1.293, training acc: 60.402
Validation loss: 1.576, validation acc: 56.016
--------------------------------------------------
[INFO]: Epoch 55 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.61it/s]


0.001
Training loss: 1.246, training acc: 61.407
Validation loss: 1.584, validation acc: 56.346
--------------------------------------------------
[INFO]: Epoch 56 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.62it/s]


0.001
Training loss: 1.227, training acc: 62.187
Validation loss: 1.544, validation acc: 56.586
--------------------------------------------------
[INFO]: Epoch 57 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 1.208, training acc: 62.577
Validation loss: 1.566, validation acc: 56.766
--------------------------------------------------
[INFO]: Epoch 58 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 1.175, training acc: 63.597
Validation loss: 1.748, validation acc: 52.145
--------------------------------------------------
[INFO]: Epoch 59 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.21s/it]


Validation


100%|██████████| 209/209 [02:08<00:00,  1.63it/s]


0.001
Training loss: 1.177, training acc: 63.147
Validation loss: 1.603, validation acc: 57.006
--------------------------------------------------
[INFO]: Epoch 60 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:12<00:00,  1.57it/s]


0.001
Training loss: 1.136, training acc: 64.692
Validation loss: 1.550, validation acc: 57.756
--------------------------------------------------
[INFO]: Epoch 61 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 1.128, training acc: 65.292
Validation loss: 1.492, validation acc: 59.016
--------------------------------------------------
[INFO]: Epoch 62 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 1.125, training acc: 64.707
Validation loss: 1.523, validation acc: 58.056
--------------------------------------------------
[INFO]: Epoch 63 of 350
Training


100%|██████████| 417/417 [08:34<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 1.081, training acc: 66.402
Validation loss: 1.473, validation acc: 59.616
--------------------------------------------------
[INFO]: Epoch 64 of 350
Training


100%|██████████| 417/417 [08:29<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:09<00:00,  1.61it/s]


0.001
Training loss: 1.031, training acc: 67.797
Validation loss: 1.487, validation acc: 60.156
--------------------------------------------------
[INFO]: Epoch 65 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:14<00:00,  1.56it/s]


0.001
Training loss: 1.032, training acc: 67.812
Validation loss: 1.489, validation acc: 60.066
--------------------------------------------------
[INFO]: Epoch 66 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:12<00:00,  1.58it/s]


0.001
Training loss: 1.008, training acc: 68.202
Validation loss: 1.549, validation acc: 58.746
--------------------------------------------------
[INFO]: Epoch 67 of 350
Training


100%|██████████| 417/417 [08:30<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.61it/s]


0.001
Training loss: 1.011, training acc: 68.517
Validation loss: 1.463, validation acc: 59.976
--------------------------------------------------
[INFO]: Epoch 68 of 350
Training


100%|██████████| 417/417 [08:31<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:12<00:00,  1.58it/s]


0.001
Training loss: 0.957, training acc: 69.642
Validation loss: 1.488, validation acc: 60.486
--------------------------------------------------
[INFO]: Epoch 69 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.21s/it]


Validation


100%|██████████| 209/209 [02:07<00:00,  1.64it/s]


0.001
Training loss: 0.949, training acc: 70.256
Validation loss: 1.572, validation acc: 58.776
--------------------------------------------------
[INFO]: Epoch 70 of 350
Training


100%|██████████| 417/417 [08:25<00:00,  1.21s/it]


Validation


100%|██████████| 209/209 [02:10<00:00,  1.60it/s]


0.001
Training loss: 0.956, training acc: 69.942
Validation loss: 1.501, validation acc: 59.586
--------------------------------------------------
[INFO]: Epoch 71 of 350
Training


100%|██████████| 417/417 [08:27<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:14<00:00,  1.56it/s]


0.001
Training loss: 0.886, training acc: 71.681
Validation loss: 1.537, validation acc: 60.006
--------------------------------------------------
[INFO]: Epoch 72 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.21s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 0.918, training acc: 71.186
Validation loss: 1.419, validation acc: 61.896
--------------------------------------------------
[INFO]: Epoch 73 of 350
Training


100%|██████████| 417/417 [08:21<00:00,  1.20s/it]


Validation


100%|██████████| 209/209 [02:11<00:00,  1.59it/s]


0.001
Training loss: 0.909, training acc: 70.571
Validation loss: 1.420, validation acc: 61.536
--------------------------------------------------
[INFO]: Epoch 74 of 350
Training


100%|██████████| 417/417 [08:26<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:17<00:00,  1.52it/s]


0.001
Training loss: 0.840, training acc: 73.676
Validation loss: 1.425, validation acc: 63.126
--------------------------------------------------
[INFO]: Epoch 75 of 350
Training


100%|██████████| 417/417 [08:33<00:00,  1.23s/it]


Validation


100%|██████████| 209/209 [02:15<00:00,  1.54it/s]


0.001
Training loss: 0.876, training acc: 72.731
Validation loss: 1.416, validation acc: 62.916
--------------------------------------------------
[INFO]: Epoch 76 of 350
Training


100%|██████████| 417/417 [08:28<00:00,  1.22s/it]


Validation


100%|██████████| 209/209 [02:15<00:00,  1.54it/s]


0.001
Training loss: 0.830, training acc: 73.556
Validation loss: 1.541, validation acc: 59.826
--------------------------------------------------
[INFO]: Epoch 77 of 350
Training


 12%|█▏        | 51/417 [01:02<07:27,  1.22s/it]


KeyboardInterrupt: 

In [None]:
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

for epoch in range(34, epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss, train_epoch_acc = train(
        model, 
        train_loader, 
        optimizer, 
        criterion,
        device,
        full_support_set
    )
    valid_epoch_loss, valid_epoch_acc = validate(
        model, 
        valid_loader, 
        criterion,
        device,
        full_support_set
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    train_acc.append(train_epoch_acc)
    valid_acc.append(valid_epoch_acc)
    print(learning_rate)
    torch.save(model.state_dict(), 'test_nw_fc_FGVCA2_w')
            
    print(f"Training loss: {train_epoch_loss:.3f}, training acc: {train_epoch_acc:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, validation acc: {valid_epoch_acc:.3f}")
    print('-'*50)

[INFO]: Epoch 35 of 350
Training


 37%|███▋      | 156/417 [03:13<05:19,  1.22s/it]