In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the attached Kaggle input directory.
input_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for path in input_paths:
    print(path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [4]:
import torchvision
import torch

In [5]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [6]:
# ImageNet-1k pretrained ConvNeXt-Base. `weights` is kept as a name because the
# dataset cell reuses it for the matching preprocessing transforms.
weights = torchvision.models.ConvNeXt_Base_Weights.IMAGENET1K_V1
model = torchvision.models.convnext_base(weights=weights)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/convnext_base-6075fbad.pth" to /root/.cache/torch/hub/checkpoints/convnext_base-6075fbad.pth
100%|██████████| 338M/338M [00:01<00:00, 195MB/s]  


In [8]:
from functools import partial
# Factory for Flowers102 splits that share one root directory, the pretrained
# weights' preprocessing transform, and automatic download.
DefaultFlowers102 = partial(
    torchvision.datasets.Flowers102,
    root="datasets",
    transform=weights.transforms(),
    download=True,
)
# Splits are built in order train -> val -> test.
train_set, valid_set, test_set = (
    DefaultFlowers102(split=split_name) for split_name in ("train", "val", "test")
)

100%|██████████| 345M/345M [00:09<00:00, 34.6MB/s] 
100%|██████████| 502/502 [00:00<00:00, 2.31MB/s]
100%|██████████| 15.0k/15.0k [00:00<00:00, 29.4MB/s]


In [9]:
from torch.utils.data import DataLoader

# Only the training loader shuffles; validation and test keep dataset order.
batch_size = 32
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader, test_loader = (
    DataLoader(eval_set, batch_size=batch_size) for eval_set in (valid_set, test_set)
)

In [10]:
# Inspect the pretrained classification head before its last layer is replaced.
classifier_head = model.classifier
print(classifier_head)

Sequential(
  (0): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=1024, out_features=1000, bias=True)
)


In [12]:
import torch.nn as nn
# Number of target classes for Flowers102.
n_classes = 102
# Swap only the final Linear layer (index 2 of the classifier head) for a freshly
# initialized one with n_classes outputs. The input width is read from the current
# layer instead of hard-coding 1024. This keeps the cell correct if the backbone
# is swapped for one with a different feature width, and re-running it is harmless.
in_features = model.classifier[2].in_features
model.classifier[2] = nn.Linear(in_features, n_classes).to(device)

In [13]:
# Re-inspect the head to confirm its final Linear layer was replaced.
print(f"{model.classifier}")

Sequential(
  (0): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): Linear(in_features=1024, out_features=102, bias=True)
)


In [14]:
# Freeze every parameter, then re-enable gradients for the classifier head only
# (its LayerNorm2d affine parameters and the new Linear layer).
model.requires_grad_(False)
model.classifier.requires_grad_(True)

In [15]:
loss_fn = nn.CrossEntropyLoss()
# Hand Adam only the parameters left trainable by the freezing cell (the
# classifier head), as in PyTorch's fine-tuning recipe. The frozen backbone
# tensors are never meant to be updated, so there is no reason to register them.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=0.001)

In [32]:
# The backbone and the new head were already placed on `device` when they were
# created; this is a safety net. Module.to moves parameters in place and returns
# the same module, so rebinding `model` is a no-op.
model = model.to(device)

def _accuracy(n_correct, n_seen):
    """Fraction of correct predictions; ``nan`` when no samples were seen."""
    return n_correct / n_seen if n_seen else float("nan")


def _train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimization pass over ``loader`` and return its training accuracy."""
    model.train()
    n_correct = n_seen = 0
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        logits = model(x_batch)
        loss = loss_fn(logits, y_batch)
        # Clear gradients before backward so stale ones already present on entry
        # (e.g. from an interrupted earlier run) never leak into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accuracy uses the logits computed before this step's parameter update.
        n_correct += (logits.argmax(dim=1) == y_batch).sum().item()
        # Count samples actually seen rather than len(dataset), which is wrong
        # under drop_last=True or a sampler that visits only a subset.
        n_seen += y_batch.size(0)
    return _accuracy(n_correct, n_seen)


def _evaluate(model, loader, device):
    """Return the accuracy of ``model`` on ``loader`` in eval mode, without gradients."""
    model.eval()
    n_correct = n_seen = 0
    with torch.no_grad():
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            predicted = model(x_batch).argmax(dim=1)
            n_correct += (predicted == y_batch).sum().item()
            n_seen += y_batch.size(0)
    return _accuracy(n_correct, n_seen)


def train(model, num_epochs, train_dl, valid_dl, loss_fn=None, optimizer=None, device=None):
    """Fine-tune ``model`` and record per-epoch train/validation accuracy.

    Each epoch runs one optimization pass over ``train_dl`` with the model in
    train mode. It then runs one evaluation pass over ``valid_dl`` in eval mode
    with gradients disabled, and prints both accuracies.

    Args:
        model: Network called as ``model(x)``. Its output is argmax-ed along
            dim 1, so it is assumed to be (batch, n_classes) class scores.
        num_epochs: Number of training epochs.
        train_dl: Iterable of ``(inputs, labels)`` training batches.
        valid_dl: Iterable of ``(inputs, labels)`` validation batches.
        loss_fn: Criterion called as ``loss_fn(logits, labels)``. Defaults to
            the notebook-level ``loss_fn``.
        optimizer: Optimizer over the model's trainable parameters. Defaults
            to the notebook-level ``optimizer``.
        device: Device each batch is moved to. Defaults to the notebook-level
            ``device``.

    Returns:
        ``(accuracy_hist_train, accuracy_hist_valid)``: two lists of length
        ``num_epochs``. Each holds the fraction of correctly classified samples
        per epoch, or ``nan`` for an epoch whose loader yielded no samples.
    """
    # Fall back to the notebook globals so the original 4-argument call still works.
    if loss_fn is None:
        loss_fn = globals()["loss_fn"]
    if optimizer is None:
        optimizer = globals()["optimizer"]
    if device is None:
        device = globals()["device"]

    accuracy_hist_train = [0.0] * num_epochs
    accuracy_hist_valid = [0.0] * num_epochs

    for epoch in range(num_epochs):
        accuracy_hist_train[epoch] = _train_one_epoch(model, train_dl, loss_fn, optimizer, device)
        accuracy_hist_valid[epoch] = _evaluate(model, valid_dl, device)
        print(f'Epoch {epoch+1} accuracy: {accuracy_hist_train[epoch]:.4f} val_accuracy: {accuracy_hist_valid[epoch]:.4f}')

    return accuracy_hist_train, accuracy_hist_valid

In [33]:
# Seed torch's global RNG, which the shuffling train_loader draws from when
# iterated during training.
# NOTE(review): the new classifier head was initialized in an earlier cell,
# before this seed, so its initial weights are not covered by it.
torch.manual_seed(1)
num_epochs = 10
hist = train(model=model, num_epochs=num_epochs, train_dl=train_loader, valid_dl=valid_loader)
print(hist)

Epoch 1 accuracy: 0.9137 val_accuracy: 0.8294
Epoch 2 accuracy: 0.9118 val_accuracy: 0.8510
Epoch 3 accuracy: 0.9451 val_accuracy: 0.8647
Epoch 4 accuracy: 0.9637 val_accuracy: 0.8833
Epoch 5 accuracy: 0.9725 val_accuracy: 0.8912
Epoch 6 accuracy: 0.9892 val_accuracy: 0.8863
Epoch 7 accuracy: 0.9863 val_accuracy: 0.8912
Epoch 8 accuracy: 0.9873 val_accuracy: 0.8990
Epoch 9 accuracy: 0.9951 val_accuracy: 0.9010
Epoch 10 accuracy: 0.9951 val_accuracy: 0.8941
([0.9137254901960784, 0.9117647058823529, 0.9450980392156862, 0.9637254901960784, 0.9725490196078431, 0.9892156862745098, 0.9862745098039216, 0.9872549019607844, 0.9950980392156863, 0.9950980392156863], [0.8294117647058824, 0.8509803921568627, 0.8647058823529412, 0.8833333333333333, 0.8911764705882353, 0.8862745098039215, 0.8911764705882353, 0.8990196078431373, 0.9009803921568628, 0.8941176470588236])
