In [None]:
!pip install torchmetrics

In [10]:
# Import Modules.
import zipfile
import torch
from torch import nn
from torch import optim
import torchvision
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import models
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from PIL import Image
from torchmetrics import ConfusionMatrix
from mlxtend.plotting import plot_confusion_matrix
import random
from pathlib import Path

In [None]:
# Optional: use this to load the saved model that is available on GitHub.
# torch.load(f="Path/to/model")

In [None]:
# Set device and seed.
# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would select "cuda" even on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Fixed seed for PyTorch's global random number generator.
torch.manual_seed(42)

In [None]:
# Download and freeze model.
# Pick the ImageNet-1k ConvNeXt-Tiny weights and take the preprocessing
# transforms that are bundled with them.
weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
transform = weights.transforms()

# Instantiate the network with those weights and move it to the chosen device.
model = models.convnext_tiny(weights=weights)
model = model.to(device)

# Turn off gradient tracking for every parameter under `model.features`;
# parameters outside that sub-module are left untouched.
for frozen_param in model.features.parameters():
  frozen_param.requires_grad = False

In [None]:
# Create dataset objects.
# Both splits are read with ImageFolder from their own root directory and share
# the same `transform`.
# NOTE: `trian_dataset` is a misspelling of "train"; the name is kept because
# the dataloader cell refers to it.
trian_dataset = datasets.ImageFolder("Path/to/dataset/train data", transform=transform)
test_dataset = datasets.ImageFolder("Path/to/dataset/test data", transform=transform)


In [21]:
# Create dataloaders
# Batches of 32 for both splits. Only the training loader shuffles; the test
# loader iterates in dataset order.
trian_dataloader = DataLoader(dataset=trian_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)


In [6]:
# Loss function and optimizer.
# Plain SGD with learning rate 0.01 over all of the model's parameters, paired
# with a cross-entropy loss.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

In [7]:
# Training/Testing loop
# Every epoch runs one optimisation pass over `trian_dataloader`, then one
# evaluation pass over `test_dataloader`, and prints a one-line summary:
#   train loss - mean of the per-batch training losses of this epoch
#   test loss  - mean of the per-batch test losses of this epoch
#   acc        - correctly classified test samples / total test samples
NUM_EPOCHS = 20

for epoch in range(NUM_EPOCHS):
  # Reset the running totals at the start of every epoch; otherwise each
  # reported value would still contain the previous epoch's average.
  total_loss = 0.0
  total_test_loss = 0.0
  correct = 0
  seen = 0

  # ---- training pass ----
  model.train()  # the mode only has to be switched once per epoch
  for X, Y in trian_dataloader:
    X, Y = X.to(device), Y.to(device)
    result = model(X)
    loss = loss_fn(result, Y)
    # Accumulate a plain Python float. Adding the loss tensor itself would keep
    # the autograd history of every batch reachable through the running total.
    total_loss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  total_loss /= len(trian_dataloader)

  # ---- evaluation pass ----
  model.eval()
  with torch.inference_mode():
    for test_X, test_Y in test_dataloader:
      test_X, test_Y = test_X.to(device), test_Y.to(device)
      test_result = model(test_X)
      total_test_loss += loss_fn(test_result, test_Y).item()
      # softmax is monotonic, so argmax over the raw outputs selects the same
      # class as argmax over the softmax probabilities.
      predicted = test_result.argmax(dim=1)
      correct += (predicted == test_Y).sum().item()
      seen += test_Y.size(0)
  total_test_loss /= len(test_dataloader)
  # Count-based accuracy weights every sample equally, even when the last
  # batch is smaller than the others.
  total_acc = correct / seen

  print(f"epoch: {epoch} | train loss {total_loss:.4f} | test loss {total_test_loss:.4f} | acc {total_acc:.4f}")