In [1]:
import os

import timm
import torch
from torch import nn
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from A.dataset import build_dataset, build_transform

In [3]:
NUM_FINETUNE_CLASSES = 5
batch_size = 260
learning_rate = 1e-3
epochs = 20
n_workers_per_device = 2
CUDA_VISIBLE_DEVICES = os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3")
NUM_WORKERS = n_workers_per_device * len(CUDA_VISIBLE_DEVICES.split(","))

In [4]:
model = timm.create_model(
    "convnextv2_base", pretrained=True, num_classes=NUM_FINETUNE_CLASSES
)
pretrained_cfg = model.pretrained_cfg

In [5]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [6]:
if torch.cuda.device_count() > 1:
    print(f"Let's use {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)

Let's use 4 GPUs!


In [7]:
train_transform, test_transform = build_transform(pretrained_cfg=pretrained_cfg)
train_ds, test_ds = build_dataset(
    "Datasets/train.csv",
    "Datasets/train_images",
    allowed_labels=[0, 4],
    train_transform=train_transform,
    test_transform=test_transform,
)
train_dataloader = DataLoader(train_ds, batch_size=batch_size, num_workers=NUM_WORKERS)
test_dataloader = DataLoader(test_ds, batch_size=batch_size, num_workers=NUM_WORKERS)

In [8]:
def train_loop(dataloader, model, loss_fn, optimizer):
    from tqdm import tqdm

    size = len(dataloader.dataset)
    # Set the model to training mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.train()
    with tqdm(total=len(dataloader.dataset)) as pbar:
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            # Compute prediction and loss
            pred = model(X)
            loss = loss_fn(pred, y)

            # Backpropagation
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            pbar.update(batch_size)
            # loss, current = loss.item(), batch * batch_size + len(X)
            # print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    # Set the model to evaluation mode - important for batch normalization and dropout layers
    # Unnecessary in this situation but added for best practices
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    # Evaluating the model with torch.no_grad() ensures that no gradients are computed during test mode
    # also serves to reduce unnecessary gradient computations and memory usage for tensors with requires_grad=True
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n"
    )

In [9]:
model.to(device)
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)

Epoch 1
-------------------------------


3380it [00:24, 139.35it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.633935 

Epoch 2
-------------------------------


3380it [00:21, 160.36it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.632216 

Epoch 3
-------------------------------


3380it [00:22, 150.37it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.619957 

Epoch 4
-------------------------------


3380it [00:21, 158.51it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.614526 

Epoch 5
-------------------------------


3380it [00:21, 156.35it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.615841 

Epoch 6
-------------------------------


3380it [00:22, 152.79it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.591790 

Epoch 7
-------------------------------


3380it [00:22, 152.01it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.599282 

Epoch 8
-------------------------------


3380it [00:21, 156.50it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.583683 

Epoch 9
-------------------------------


3380it [00:22, 152.04it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.571710 

Epoch 10
-------------------------------


3380it [00:21, 155.32it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.597754 

Epoch 11
-------------------------------


3380it [00:22, 150.27it/s]                          


Test Error: 
 Accuracy: 74.7%, Avg loss: 0.555786 

Epoch 12
-------------------------------


3380it [00:22, 149.63it/s]                          


Test Error: 
 Accuracy: 70.3%, Avg loss: 0.571234 

Epoch 13
-------------------------------


3380it [00:21, 160.44it/s]                          


Test Error: 
 Accuracy: 73.8%, Avg loss: 0.561636 

Epoch 14
-------------------------------


3380it [00:21, 159.63it/s]                          


Test Error: 
 Accuracy: 74.9%, Avg loss: 0.536845 

Epoch 15
-------------------------------


3380it [00:21, 154.62it/s]                          


Test Error: 
 Accuracy: 75.2%, Avg loss: 0.536684 

Epoch 16
-------------------------------


3380it [00:20, 161.21it/s]                          


Test Error: 
 Accuracy: 76.6%, Avg loss: 0.540083 

Epoch 17
-------------------------------


3380it [00:21, 155.10it/s]                          


Test Error: 
 Accuracy: 76.8%, Avg loss: 0.508560 

Epoch 18
-------------------------------


3380it [00:21, 157.17it/s]                          


Test Error: 
 Accuracy: 74.9%, Avg loss: 0.552545 

Epoch 19
-------------------------------


3380it [00:21, 156.84it/s]                          


Test Error: 
 Accuracy: 76.6%, Avg loss: 0.506231 

Epoch 20
-------------------------------


3380it [00:22, 152.67it/s]                          


Test Error: 
 Accuracy: 77.4%, Avg loss: 0.496496 

