In [1]:
from tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import utils
from torch.utils import data
from torchvision.datasets import ImageFolder
from torchvision import transforms

In [2]:
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.Grayscale(),
    transforms.ToTensor(),
])

In [3]:
train_dataset = ImageFolder('../data/train-cat-rabbit', transform=train_transforms)
train_dataset

Dataset ImageFolder
    Number of datapoints: 1600
    Root location: ../data/train-cat-rabbit
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               RandomHorizontalFlip(p=0.5)
               RandomRotation(degrees=[-30.0, 30.0], interpolation=nearest, expand=False, fill=0)
               Grayscale(num_output_channels=1)
               ToTensor()
           )

In [41]:
test_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])
test_dataset = ImageFolder('../data/test-images', transform=test_transforms)
test_dataset

Dataset ImageFolder
    Number of datapoints: 15
    Root location: ../data/test-images
    StandardTransform
Transform: Compose(
               Resize(size=(256, 256), interpolation=bilinear, max_size=None, antialias=None)
               Grayscale(num_output_channels=1)
               ToTensor()
           )

In [42]:
validation_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.Grayscale(),
    transforms.ToTensor(),
])
val_dataset = ImageFolder('../data/val-cat-rabbit/', transform=validation_transforms)
val_dataset

Dataset ImageFolder
    Number of datapoints: 414
    Root location: ../data/val-cat-rabbit/
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               Grayscale(num_output_channels=1)
               ToTensor()
           )

In [43]:
BATCH_SIZE = 64
train_loader = data.DataLoader(train_dataset,BATCH_SIZE,True)
test_loader = data.DataLoader(test_dataset,BATCH_SIZE)
val_loader = data.DataLoader(val_dataset,BATCH_SIZE,True)

In [7]:
sample_train = train_loader.__iter__().__next__()
sample_train

[tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           ...,
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]],
 
 
         [[[0.0000, 0.0000, 0.3569,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.3608,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.3647,  ..., 0.3647, 0.3569, 0.3451],
           ...,
           [0.3176, 0.3451, 0.3647,  ..., 0.3020, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.2353, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.2745, 0.0000, 0.0000]]],
 
 
         [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
           [0.0000

In [8]:
sample_train[0][:16].shape

torch.Size([16, 1, 256, 256])

In [9]:
grid = utils.make_grid(sample_train[0]).detach().numpy().transpose([1,2,0])
grid.shape

(2066, 2066, 3)

In [10]:
# from skimage.viewer import ImageViewer
# ImageViewer(grid).show()

In [11]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv_1 = nn.Conv2d(1,32,3)
        self.conv_2 = nn.Conv2d(32,64,4)
        self.conv_3 = nn.Conv2d(64, 128, 3)
        self.avg_pool = nn.AvgPool2d(2)
        self.dense_1 = nn.Linear(128*30*30,64)
        self.dense_2 = nn.Linear(64,1)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = nn.functional.relu(self.conv_1(x))
        x = self.avg_pool(x)
        x = self.dropout(x)

        x = nn.functional.relu(self.conv_2(x))
        x = self.avg_pool(x)
        x = self.dropout(x)

        x = nn.functional.relu(self.conv_3(x))
        x = self.avg_pool(x)
        x = self.dropout(x)
        x = nn.Flatten()(x)
        x = nn.functional.relu(self.dense_1(x))

        x = self.dense_2(x)
        return torch.sigmoid(x)

model = ConvNet()
# load_data = torch.load('../models/model_checkpoint.pt')
# model.load_state_dict(load_data['model'])

In [12]:
optimizer = torch.optim.Adam(model.parameters(),0.001, weight_decay=0.00001)
# optimizer.load_state_dict(load_data['optimizers'])
criterion = nn.BCELoss()

In [13]:
def train(model, optimizer, criterion, data_loader, epoch:list[int,int],decay=1e-7,mode='train'):
    pbar = tqdm(data_loader)
    pbar.set_description_str(f'Epoch {epoch[0]}/{epoch[1]} ({mode})')
    running_loss = 0.0
    running_accurate = 0
    batch = None
    for images, labels in pbar:
        labels = labels.view((-1,1)).float()
        if mode == 'train':
            model.train()
            optimizer.zero_grad()
        else:
            model.eval()
        outs = model(images)
        preds = torch.round(outs)
        accurate = (preds == labels).sum()
        running_accurate += accurate
        if batch is None:
            batch = labels.size(0)
        accuracy = accurate / labels.size(0)
        loss = criterion(outs, labels)
        reg_loss = 0.0
        for param in model.parameters():
            reg_loss += torch.norm(param,2) * decay
        loss += reg_loss
        running_loss += loss

        if mode == 'train':
            loss.backward()
            optimizer.step()
        pbar.set_postfix({'loss':loss.item(),'accuracy':accuracy.item()})

    epoch_accuracy = running_accurate.item() / (len(data_loader) * batch)
    print(f'{mode}_loss: {running_loss.item():.4f}, {mode}_accuracy: {epoch_accuracy*100:.4f}%')
    return running_loss.item(), epoch_accuracy

In [14]:
def fit(model, optimizer, criterion, train_loader, validation_loader, epochs, decay=1e-7):
    train_info = {'loss':[],'accuracy':[], 'loss_val':[],'accuracy_val':[]}
    for epoch in range(1, epochs+1):
        loss, accuracy = train(model, optimizer, criterion, train_loader,[epoch, epochs],decay,'train')
        train_info['loss'].append(loss)
        train_info['accuracy'].append(accuracy)
        val_loss, val_accuracy = train(model, optimizer, criterion, validation_loader,[epoch, epochs],decay,'validation')
        train_info['loss_val'].append(val_loss)
        train_info['accuracy_val'].append(val_accuracy)
    return train_info

In [15]:
train_data = fit(model, optimizer, criterion, train_loader, val_loader, 50, decay=0)

Epoch 1/50 (train): 100%|██████████| 25/25 [01:59<00:00,  4.79s/it, loss=0.645, accuracy=0.594]


train_loss: 17.2293, train_accuracy: 59.5000%


Epoch 1/50 (validation): 100%|██████████| 7/7 [00:19<00:00,  2.85s/it, loss=0.619, accuracy=0.633]


validation_loss: 4.2082, validation_accuracy: 61.6071%


Epoch 2/50 (train): 100%|██████████| 25/25 [02:30<00:00,  6.01s/it, loss=0.595, accuracy=0.656]


train_loss: 15.0201, train_accuracy: 67.7500%


Epoch 2/50 (validation): 100%|██████████| 7/7 [00:16<00:00,  2.41s/it, loss=0.569, accuracy=0.567]


validation_loss: 4.0876, validation_accuracy: 62.7232%


Epoch 3/50 (train): 100%|██████████| 25/25 [01:52<00:00,  4.50s/it, loss=0.571, accuracy=0.734]


train_loss: 14.4934, train_accuracy: 68.5000%


Epoch 3/50 (validation): 100%|██████████| 7/7 [00:11<00:00,  1.68s/it, loss=0.626, accuracy=0.667]


validation_loss: 4.0967, validation_accuracy: 64.2857%


Epoch 4/50 (train): 100%|██████████| 25/25 [03:07<00:00,  7.51s/it, loss=0.625, accuracy=0.609]


train_loss: 14.3040, train_accuracy: 68.8750%


Epoch 4/50 (validation): 100%|██████████| 7/7 [00:20<00:00,  2.97s/it, loss=0.457, accuracy=0.733]


validation_loss: 3.9385, validation_accuracy: 63.6161%


Epoch 5/50 (train): 100%|██████████| 25/25 [03:23<00:00,  8.15s/it, loss=0.541, accuracy=0.75] 


train_loss: 14.2502, train_accuracy: 68.4375%


Epoch 5/50 (validation): 100%|██████████| 7/7 [00:15<00:00,  2.25s/it, loss=0.579, accuracy=0.667]


validation_loss: 3.9317, validation_accuracy: 66.0714%


Epoch 6/50 (train): 100%|██████████| 25/25 [02:05<00:00,  5.01s/it, loss=0.597, accuracy=0.688]


train_loss: 14.1354, train_accuracy: 69.5000%


Epoch 6/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.75s/it, loss=0.505, accuracy=0.733]


validation_loss: 3.7891, validation_accuracy: 64.9554%


Epoch 7/50 (train): 100%|██████████| 25/25 [02:14<00:00,  5.40s/it, loss=0.581, accuracy=0.719]


train_loss: 13.5361, train_accuracy: 72.1875%


Epoch 7/50 (validation): 100%|██████████| 7/7 [00:17<00:00,  2.53s/it, loss=0.612, accuracy=0.667]


validation_loss: 3.8993, validation_accuracy: 63.8393%


Epoch 8/50 (train): 100%|██████████| 25/25 [02:42<00:00,  6.50s/it, loss=0.527, accuracy=0.688]


train_loss: 12.9423, train_accuracy: 72.5625%


Epoch 8/50 (validation): 100%|██████████| 7/7 [00:17<00:00,  2.47s/it, loss=0.679, accuracy=0.6]  


validation_loss: 3.9268, validation_accuracy: 64.9554%


Epoch 9/50 (train): 100%|██████████| 25/25 [02:42<00:00,  6.51s/it, loss=0.526, accuracy=0.719]


train_loss: 12.9279, train_accuracy: 72.1250%


Epoch 9/50 (validation): 100%|██████████| 7/7 [00:16<00:00,  2.34s/it, loss=0.46, accuracy=0.767] 


validation_loss: 3.5921, validation_accuracy: 66.0714%


Epoch 10/50 (train): 100%|██████████| 25/25 [02:25<00:00,  5.80s/it, loss=0.515, accuracy=0.734]


train_loss: 12.2208, train_accuracy: 74.9375%


Epoch 10/50 (validation): 100%|██████████| 7/7 [00:10<00:00,  1.44s/it, loss=0.367, accuracy=0.767]


validation_loss: 3.5533, validation_accuracy: 66.5179%


Epoch 11/50 (train): 100%|██████████| 25/25 [01:58<00:00,  4.72s/it, loss=0.455, accuracy=0.828]


train_loss: 12.2005, train_accuracy: 75.3125%


Epoch 11/50 (validation): 100%|██████████| 7/7 [00:13<00:00,  1.95s/it, loss=0.59, accuracy=0.7]   


validation_loss: 3.6673, validation_accuracy: 66.7411%


Epoch 12/50 (train): 100%|██████████| 25/25 [02:25<00:00,  5.82s/it, loss=0.48, accuracy=0.734] 


train_loss: 11.6151, train_accuracy: 75.1250%


Epoch 12/50 (validation): 100%|██████████| 7/7 [00:24<00:00,  3.57s/it, loss=0.526, accuracy=0.767]


validation_loss: 3.4597, validation_accuracy: 68.3036%


Epoch 13/50 (train): 100%|██████████| 25/25 [03:35<00:00,  8.63s/it, loss=0.455, accuracy=0.766]


train_loss: 10.7169, train_accuracy: 79.2500%


Epoch 13/50 (validation): 100%|██████████| 7/7 [00:22<00:00,  3.19s/it, loss=0.395, accuracy=0.8]  


validation_loss: 3.4105, validation_accuracy: 68.5268%


Epoch 14/50 (train): 100%|██████████| 25/25 [03:29<00:00,  8.36s/it, loss=0.334, accuracy=0.812]


train_loss: 10.6200, train_accuracy: 77.9375%


Epoch 14/50 (validation): 100%|██████████| 7/7 [00:19<00:00,  2.77s/it, loss=0.532, accuracy=0.733]


validation_loss: 3.4670, validation_accuracy: 71.4286%


Epoch 15/50 (train): 100%|██████████| 25/25 [03:21<00:00,  8.05s/it, loss=0.363, accuracy=0.797]


train_loss: 9.9935, train_accuracy: 80.7500%


Epoch 15/50 (validation): 100%|██████████| 7/7 [00:22<00:00,  3.21s/it, loss=0.374, accuracy=0.867]


validation_loss: 3.3664, validation_accuracy: 70.0893%


Epoch 16/50 (train): 100%|██████████| 25/25 [03:09<00:00,  7.60s/it, loss=0.535, accuracy=0.75] 


train_loss: 9.6591, train_accuracy: 82.2500%


Epoch 16/50 (validation): 100%|██████████| 7/7 [00:19<00:00,  2.81s/it, loss=0.351, accuracy=0.9]  


validation_loss: 3.1807, validation_accuracy: 73.2143%


Epoch 17/50 (train): 100%|██████████| 25/25 [03:07<00:00,  7.48s/it, loss=0.342, accuracy=0.797]


train_loss: 8.8954, train_accuracy: 84.1250%


Epoch 17/50 (validation): 100%|██████████| 7/7 [00:14<00:00,  2.10s/it, loss=0.51, accuracy=0.733] 


validation_loss: 3.3058, validation_accuracy: 72.7679%


Epoch 18/50 (train): 100%|██████████| 25/25 [02:24<00:00,  5.79s/it, loss=0.381, accuracy=0.812]


train_loss: 8.2529, train_accuracy: 84.9375%


Epoch 18/50 (validation): 100%|██████████| 7/7 [00:19<00:00,  2.86s/it, loss=0.65, accuracy=0.767] 


validation_loss: 3.5181, validation_accuracy: 73.4375%


Epoch 19/50 (train): 100%|██████████| 25/25 [02:19<00:00,  5.59s/it, loss=0.233, accuracy=0.906]


train_loss: 8.1685, train_accuracy: 84.6250%


Epoch 19/50 (validation): 100%|██████████| 7/7 [00:20<00:00,  2.96s/it, loss=0.485, accuracy=0.833]


validation_loss: 3.6426, validation_accuracy: 73.2143%


Epoch 20/50 (train): 100%|██████████| 25/25 [02:41<00:00,  6.47s/it, loss=0.352, accuracy=0.828]


train_loss: 8.3097, train_accuracy: 85.1875%


Epoch 20/50 (validation): 100%|██████████| 7/7 [00:14<00:00,  2.10s/it, loss=0.352, accuracy=0.833]


validation_loss: 3.3973, validation_accuracy: 70.9821%


Epoch 21/50 (train): 100%|██████████| 25/25 [02:29<00:00,  6.00s/it, loss=0.262, accuracy=0.906]


train_loss: 7.8260, train_accuracy: 85.8750%


Epoch 21/50 (validation): 100%|██████████| 7/7 [00:22<00:00,  3.27s/it, loss=0.667, accuracy=0.833]


validation_loss: 3.7157, validation_accuracy: 74.3304%


Epoch 22/50 (train): 100%|██████████| 25/25 [02:44<00:00,  6.59s/it, loss=0.199, accuracy=0.922]


train_loss: 7.2895, train_accuracy: 87.3125%


Epoch 22/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.74s/it, loss=0.602, accuracy=0.733]


validation_loss: 3.3901, validation_accuracy: 74.3304%


Epoch 23/50 (train): 100%|██████████| 25/25 [02:29<00:00,  6.00s/it, loss=0.336, accuracy=0.844]


train_loss: 6.8953, train_accuracy: 88.3125%


Epoch 23/50 (validation): 100%|██████████| 7/7 [00:14<00:00,  2.05s/it, loss=0.47, accuracy=0.767] 


validation_loss: 3.2372, validation_accuracy: 73.4375%


Epoch 24/50 (train): 100%|██████████| 25/25 [02:29<00:00,  5.98s/it, loss=0.217, accuracy=0.859]


train_loss: 6.4982, train_accuracy: 88.8750%


Epoch 24/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.85s/it, loss=0.57, accuracy=0.733] 


validation_loss: 3.4702, validation_accuracy: 73.8839%


Epoch 25/50 (train): 100%|██████████| 25/25 [02:19<00:00,  5.58s/it, loss=0.182, accuracy=0.938]


train_loss: 5.7813, train_accuracy: 91.5625%


Epoch 25/50 (validation): 100%|██████████| 7/7 [00:17<00:00,  2.51s/it, loss=0.455, accuracy=0.767]


validation_loss: 3.7469, validation_accuracy: 75.2232%


Epoch 26/50 (train): 100%|██████████| 25/25 [03:10<00:00,  7.62s/it, loss=0.194, accuracy=0.969]


train_loss: 6.6086, train_accuracy: 88.8750%


Epoch 26/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.76s/it, loss=0.35, accuracy=0.767] 


validation_loss: 3.1636, validation_accuracy: 73.6607%


Epoch 27/50 (train): 100%|██████████| 25/25 [02:35<00:00,  6.20s/it, loss=0.164, accuracy=0.922]


train_loss: 5.4767, train_accuracy: 90.7500%


Epoch 27/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.73s/it, loss=0.39, accuracy=0.833] 


validation_loss: 3.3997, validation_accuracy: 74.5536%


Epoch 28/50 (train): 100%|██████████| 25/25 [02:03<00:00,  4.93s/it, loss=0.281, accuracy=0.922]


train_loss: 5.1026, train_accuracy: 91.8750%


Epoch 28/50 (validation): 100%|██████████| 7/7 [00:11<00:00,  1.68s/it, loss=0.197, accuracy=0.9]  


validation_loss: 3.5961, validation_accuracy: 75.2232%


Epoch 29/50 (train): 100%|██████████| 25/25 [02:11<00:00,  5.25s/it, loss=0.13, accuracy=0.953] 


train_loss: 4.5989, train_accuracy: 92.3750%


Epoch 29/50 (validation): 100%|██████████| 7/7 [00:14<00:00,  2.04s/it, loss=0.611, accuracy=0.867]


validation_loss: 3.9136, validation_accuracy: 74.3304%


Epoch 30/50 (train): 100%|██████████| 25/25 [03:07<00:00,  7.49s/it, loss=0.288, accuracy=0.875]


train_loss: 5.0193, train_accuracy: 91.1250%


Epoch 30/50 (validation): 100%|██████████| 7/7 [00:13<00:00,  1.95s/it, loss=0.417, accuracy=0.833]


validation_loss: 4.0211, validation_accuracy: 73.4375%


Epoch 31/50 (train): 100%|██████████| 25/25 [02:06<00:00,  5.07s/it, loss=0.186, accuracy=0.922]


train_loss: 4.6584, train_accuracy: 92.1250%


Epoch 31/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.79s/it, loss=0.511, accuracy=0.767]


validation_loss: 3.6524, validation_accuracy: 76.1161%


Epoch 32/50 (train): 100%|██████████| 25/25 [02:06<00:00,  5.06s/it, loss=0.294, accuracy=0.891] 


train_loss: 4.0755, train_accuracy: 94.1250%


Epoch 32/50 (validation): 100%|██████████| 7/7 [00:13<00:00,  1.96s/it, loss=0.427, accuracy=0.833]


validation_loss: 4.5964, validation_accuracy: 74.5536%


Epoch 33/50 (train): 100%|██████████| 25/25 [02:20<00:00,  5.61s/it, loss=0.132, accuracy=0.938] 


train_loss: 3.6884, train_accuracy: 94.5000%


Epoch 33/50 (validation): 100%|██████████| 7/7 [00:21<00:00,  3.01s/it, loss=1.3, accuracy=0.7]    


validation_loss: 4.5792, validation_accuracy: 76.3393%


Epoch 34/50 (train): 100%|██████████| 25/25 [02:38<00:00,  6.34s/it, loss=0.173, accuracy=0.922]


train_loss: 3.4205, train_accuracy: 94.7500%


Epoch 34/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.72s/it, loss=0.491, accuracy=0.867]


validation_loss: 4.5870, validation_accuracy: 73.8839%


Epoch 35/50 (train): 100%|██████████| 25/25 [02:02<00:00,  4.90s/it, loss=0.123, accuracy=0.969] 


train_loss: 3.7553, train_accuracy: 94.4375%


Epoch 35/50 (validation): 100%|██████████| 7/7 [00:11<00:00,  1.59s/it, loss=0.821, accuracy=0.7]  


validation_loss: 4.6388, validation_accuracy: 73.4375%


Epoch 36/50 (train): 100%|██████████| 25/25 [01:52<00:00,  4.52s/it, loss=0.143, accuracy=0.938] 


train_loss: 2.9027, train_accuracy: 95.3750%


Epoch 36/50 (validation): 100%|██████████| 7/7 [00:12<00:00,  1.72s/it, loss=0.311, accuracy=0.9]  


validation_loss: 4.3487, validation_accuracy: 75.2232%


Epoch 37/50 (train): 100%|██████████| 25/25 [01:47<00:00,  4.30s/it, loss=0.0819, accuracy=0.969]


train_loss: 3.2783, train_accuracy: 95.0000%


Epoch 37/50 (validation): 100%|██████████| 7/7 [00:09<00:00,  1.37s/it, loss=0.773, accuracy=0.8]  


validation_loss: 4.8962, validation_accuracy: 74.1071%


Epoch 38/50 (train): 100%|██████████| 25/25 [01:43<00:00,  4.13s/it, loss=0.0843, accuracy=0.953]


train_loss: 2.5748, train_accuracy: 96.4375%


Epoch 38/50 (validation): 100%|██████████| 7/7 [00:09<00:00,  1.42s/it, loss=0.554, accuracy=0.867]


validation_loss: 4.6701, validation_accuracy: 75.0000%


Epoch 39/50 (train): 100%|██████████| 25/25 [01:43<00:00,  4.14s/it, loss=0.118, accuracy=0.969] 


train_loss: 2.9777, train_accuracy: 95.2500%


Epoch 39/50 (validation): 100%|██████████| 7/7 [00:09<00:00,  1.42s/it, loss=0.404, accuracy=0.833]


validation_loss: 4.8761, validation_accuracy: 73.8839%


Epoch 40/50 (train): 100%|██████████| 25/25 [01:43<00:00,  4.14s/it, loss=0.0617, accuracy=0.969]


train_loss: 2.1166, train_accuracy: 97.0625%


Epoch 40/50 (validation): 100%|██████████| 7/7 [00:09<00:00,  1.43s/it, loss=0.556, accuracy=0.8]  


validation_loss: 5.4594, validation_accuracy: 74.7768%


Epoch 41/50 (train): 100%|██████████| 25/25 [01:44<00:00,  4.16s/it, loss=0.0887, accuracy=0.953]


train_loss: 2.3026, train_accuracy: 96.9375%


Epoch 41/50 (validation): 100%|██████████| 7/7 [00:09<00:00,  1.41s/it, loss=0.229, accuracy=0.9]  


validation_loss: 5.3325, validation_accuracy: 75.4464%


Epoch 42/50 (train):  64%|██████▍   | 16/25 [01:18<00:44,  4.93s/it, loss=0.0467, accuracy=0.984]


KeyboardInterrupt: 

In [16]:
torch.save({'model':model.state_dict(),'optimizers':optimizer.state_dict()},'../models/model_checkpoint.pt')

In [45]:
train(model, optimizer, criterion, test_loader,[0,1],0,'test')

Epoch 0/1 (test): 100%|██████████| 1/1 [00:00<00:00,  1.98it/s, loss=0.225, accuracy=0.867]

test_loss: 0.2253, test_accuracy: 86.6667%





(0.22528241574764252, 0.8666666666666667)