In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Dataset location and loader settings, kept in one place so they are easy to tune.
DATA_ROOT = "~/datasets/kaggle-car-truck"
IMAGE_SIZE = (128, 128)
BATCH_SIZE = 64

# NOTE(review): ImageNet-pretrained VGG16 weights are normally fed inputs normalized
# with transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).
# Consider adding it here (and un-normalizing before plotting images) — TODO confirm.
transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

ds_train = datasets.ImageFolder(f"{DATA_ROOT}/train", transform=transform)
ds_valid = datasets.ImageFolder(f"{DATA_ROOT}/valid", transform=transform)

# Shuffle only the training data; validation metrics do not depend on order,
# so a fixed order keeps evaluation deterministic.
dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
dl_valid = DataLoader(ds_valid, batch_size=BATCH_SIZE, shuffle=False)

# Phase name -> loader, consumed by the training loop.
dl_dict = {"train": dl_train, "valid": dl_valid}

In [None]:
# Preview the first image of one (shuffled) training batch, titled with its class name.
images, labels = next(iter(dl_train))
fig, ax = plt.subplots()
# Image tensors are channel-first; matplotlib wants height x width x channels.
ax.imshow(images[0].numpy().transpose((1, 2, 0)))
ax.set_title(ds_train.classes[labels[0].item()])
ax.axis("off");

In [None]:
import torch.nn as nn
from torchvision import models
from torchvision.models.vgg import vgg16


class Net(nn.Module):
    """ImageNet-pretrained VGG16 with its last layer replaced for ``num_classes`` outputs.

    The convolutional feature extractor is frozen; only the classifier head
    receives gradients.

    Args:
        num_classes: Number of output logits of the new final layer. Defaults to
            2, the value the original head was hard-coded to.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # `weights` must be a VGG16_Weights enum *member* (or None); passing the
        # enum class itself is not a valid weights value.
        self.vgg = vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1, progress=True)

        # Swap the final classifier layer, reusing its input width (4096 for VGG16).
        in_features = self.vgg.classifier[6].in_features
        self.vgg.classifier[6] = nn.Linear(in_features, num_classes)

        # Freeze the pretrained backbone ...
        for param in self.vgg.features.parameters():
            param.requires_grad = False
        # (no-op if avgpool has no parameters; kept for explicitness)
        for param in self.vgg.avgpool.parameters():
            param.requires_grad = False
        # ... and train only the classifier head.
        for param in self.vgg.classifier.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Return the VGG16 output (one score per class) for a batch of images."""
        return self.vgg(x)


# Build the model and put it in training mode; `train()` returns the module
# itself, and the bare `net` below displays the architecture as cell output.
net = Net().train()
net

In [None]:
from torch import optim

# CrossEntropyLoss expects raw (un-softmaxed) class scores and integer labels.
criterion = nn.CrossEntropyLoss()

# Give the optimizer only the parameters that are actually trained (the
# classifier head); the frozen VGG features never receive gradients.
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)

In [None]:
from tqdm import tqdm

def select_device() -> torch.device:
    """Return the best available device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def run_epoch(model, loader, phase, criterion, optimizer, device):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    In the ``"train"`` phase the model is set to train mode and ``optimizer``
    is stepped after every batch; any other phase sets eval mode and disables
    gradient tracking. Both metrics are averaged over ``len(loader.dataset)``.
    """
    is_train = phase == "train"
    model.train(is_train)  # train(False) is equivalent to eval()

    running_loss = 0.0
    running_corrects = 0
    for inputs, labels in tqdm(loader, desc=phase):
        inputs = inputs.to(device)
        labels = labels.to(device)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

        preds = outputs.argmax(dim=1)
        # Weight by batch size so the mean is exact even when the last batch is short.
        running_loss += loss.item() * inputs.size(0)
        # Accumulate on the host as a Python int: MPS has no float64 support,
        # so calling .double() on a device tensor would fail there.
        running_corrects += (preds == labels).sum().item()

    dataset_size = len(loader.dataset)
    return running_loss / dataset_size, running_corrects / dataset_size


device = select_device()
net = net.to(device)

# Per-epoch validation history.
accuracy_list = []
loss_list = []

num_epochs = 30
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch + 1, num_epochs))
    print('-------------')

    for phase in ["train", "valid"]:
        # Metrics are computed once per phase, after all batches have been seen.
        epoch_loss, epoch_acc = run_epoch(
            net, dl_dict[phase], phase, criterion, optimizer, device
        )
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        # Must match the "valid" key used in dl_dict.
        if phase == "valid":
            accuracy_list.append(epoch_acc)
            loss_list.append(epoch_loss)

In [None]:
# Flattened size of VGG16's pooled feature map (512 channels x 7 x 7): the
# input width of the first classifier layer. Checked against the live model.
VGG16_FLAT_FEATURES = 512 * 7 * 7
assert net.vgg.classifier[0].in_features == VGG16_FLAT_FEATURES
VGG16_FLAT_FEATURES