In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np

import utils

@torch.no_grad()
def validate(model, arch, classifier, epoch, dataloader, device="cpu", status="validation"):
    model.eval()
    classifier.eval()
    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    all_outputs = []
    all_targets = []

    with torch.no_grad():
        for step, (inp, target) in enumerate(dataloader, start=1):
            inp = inp.to(device)
            target = target.to(device)
            batch_size = inp.size(0)

            # Forward pass
            if "vit" in arch:
                intermediate_output = model.get_intermediate_layers(inp, 1)
                output_features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                avg_pooled = torch.mean(intermediate_output[-1][:, 1:], dim=1)
                output_features = torch.cat((output_features, avg_pooled), dim=-1)
            else:
                output_features = model(inp)
            output = classifier(output_features)

            # Compute loss
            loss = nn.CrossEntropyLoss()(output, target)

            # Accumulate loss and samples
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # Compute accuracy
            _, pred = torch.max(output, dim=1)
            total_correct += pred.eq(target).sum().item()

            # Accumulate outputs and targets for F1-score computation
            all_outputs.append(output.cpu())
            all_targets.append(target.cpu())

    # Compute average loss and accuracy
    avg_loss = total_loss / total_samples
    avg_acc = 100.0 * total_correct / total_samples

    # Concatenate all outputs and targets
    all_outputs = torch.cat(all_outputs)
    all_targets = torch.cat(all_targets)

    # Compute F1-score
    avg_f1 = utils.f1_score(all_outputs, all_targets, classifier.num_labels)

    # Log statistics
    val_stat = {
        'val_loss': avg_loss,
        'val_acc1': avg_acc,
        'val_f1': avg_f1
    }
    if status == "validation":
        print(
            "[VALID] "
            f"epoch: {epoch + 1:03d} | "
            f"valid acc: {val_stat['val_acc1']:05.2f} | "
            f"valid f1: {val_stat['val_f1']:.4f}"
        )
    elif status == "test":
        print(
            "[TEST] "
            f"test acc: {val_stat['val_acc1']:05.2f} | "
            f"test f1: {val_stat['val_f1']:.4f}"
        )
    else:
        raise ValueError("Set the valid status. Avaliable status is \"validation\" or \"test\".")
    return val_stat

In [2]:
# Linear classifier training setting using DINO-pretrained ResNet-50 as backbone

MODEL_NAME = "resnet50"
SSL_METHOD = "DINO"
GPU_NUM = 1

num_labels = 39

epochs = 50
batch_size = 32
lr = 1e-4
log_interval = 10

device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')

In [3]:
# Initialize the model and classifier for plant disease classification.

model, classifier = utils.init_pretrained_model(MODEL_NAME, SSL_METHOD, num_labels=num_labels, device=device)

[ok] Loaded cleanly: LeafVision_DINO_resnet50.pth


In [4]:
# Implement your plant disease dataset/dataloader, optimizer, and scheduler (optional) here.

train_dataset_path = "./dataset/PV/05images/train"
valid_dataset_path = "./dataset/PV/05images/valid"
test_dataset_path = "./dataset/PV/test"

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4371, 0.5177, 0.3476), (0.1789, 0.1545, 0.1923)),
])
test_transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.4371, 0.5177, 0.3476), (0.1789, 0.1545, 0.1923)),
])

train_set = datasets.ImageFolder(train_dataset_path, transform=train_transform)
valid_set = datasets.ImageFolder(valid_dataset_path, transform=test_transform)
test_set = datasets.ImageFolder(test_dataset_path, transform=test_transform)

train_loader = DataLoader(train_set, batch_size=batch_size)
valid_loader = DataLoader(valid_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

optimizer = torch.optim.Adam(
    params=[
        {'params': classifier.parameters(), 'lr': lr}
    ],
    lr=lr
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, eta_min=0)

In [5]:
# Train the linear classifier.

model.eval()
classifier.train()

best_valid_loss = np.inf
best_classifier = None

for epoch in range(epochs):
    total_loss = 0.0
    total_samples = 0
    total_correct = 0
    all_outputs = []
    all_targets = []

    for step, (inp, target) in enumerate(train_loader, start=1):
        inp = inp.to(device)
        target = target.to(device)
        batch_size = inp.size(0)

        # Forward pass
        with torch.no_grad():
            if "vit" in MODEL_NAME:
                intermediate_output = model.get_intermediate_layers(inp, 1)
                output_features = torch.cat([x[:, 0] for x in intermediate_output], dim=-1)
                avg_pooled = torch.mean(intermediate_output[-1][:, 1:], dim=1)
                output_features = torch.cat((output_features, avg_pooled), dim=-1)
            else:
                output_features = model(inp)
        output = classifier(output_features)

        # Compute loss
        loss = nn.CrossEntropyLoss()(output, target)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate loss and samples
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        # Compute accuracy
        _, pred = torch.max(output, dim=1)
        total_correct += pred.eq(target).sum().item()

        # Accumulate outputs and targets for F1-score computation
        all_outputs.append(output.detach().cpu())
        all_targets.append(target.detach().cpu())

    if scheduler is not None:
        scheduler.step()
    
    # Compute average loss and accuracy
    avg_loss = total_loss / total_samples
    avg_acc = 100.0 * total_correct / total_samples

    # Concatenate all outputs and targets
    all_outputs = torch.cat(all_outputs)
    all_targets = torch.cat(all_targets)

    # Compute F1-score
    avg_f1 = utils.f1_score(all_outputs, all_targets, classifier.num_labels)

    # Log statistics
    train_stat = {
        'epoch': epoch + 1,
        'train_loss': avg_loss,
        'lr': optimizer.param_groups[0]["lr"],
        'train_acc1': avg_acc,
        'train_f1': avg_f1
    }

    print(
        "[TRAIN] "
        f"epoch: {epoch + 1:03d} | "
        f"train acc: {train_stat['train_acc1']:05.2f} | "
        f"train f1: {train_stat['train_f1']:.4f}"
    )

    val_stat = validate(model, MODEL_NAME, classifier, epoch, valid_loader, device=device)
    if val_stat['val_loss'] < best_valid_loss:
        best_valid_loss = val_stat['val_loss']
        best_classifier = classifier
        classifier.to(device)
    print(f"Finished training {epoch + 1} epoch")

[TRAIN] epoch: 001 | train acc: 00.51 | train f1: 0.0011
[VALID] epoch: 001 | valid acc: 02.56 | valid f1: 0.0114
Finished training 1 epoch
[TRAIN] epoch: 002 | train acc: 04.62 | train f1: 0.0359
[VALID] epoch: 002 | valid acc: 05.13 | valid f1: 0.0556
Finished training 2 epoch
[TRAIN] epoch: 003 | train acc: 10.77 | train f1: 0.0651
[VALID] epoch: 003 | valid acc: 10.26 | valid f1: 0.0893
Finished training 3 epoch
[TRAIN] epoch: 004 | train acc: 09.23 | train f1: 0.0629
[VALID] epoch: 004 | valid acc: 07.69 | valid f1: 0.0614
Finished training 4 epoch
[TRAIN] epoch: 005 | train acc: 08.72 | train f1: 0.0749
[VALID] epoch: 005 | valid acc: 07.69 | valid f1: 0.0613
Finished training 5 epoch
[TRAIN] epoch: 006 | train acc: 07.18 | train f1: 0.0570
[VALID] epoch: 006 | valid acc: 07.69 | valid f1: 0.0613
Finished training 6 epoch
[TRAIN] epoch: 007 | train acc: 07.69 | train f1: 0.0664
[VALID] epoch: 007 | valid acc: 12.82 | valid f1: 0.1212
Finished training 7 epoch
[TRAIN] epoch: 008 |

In [None]:
# Test the trained linear classifier.

validate(model, MODEL_NAME, best_classifier, epochs, test_loader, device=device, status="test")

[TEST] test acc: 86.37 | test f1: 0.8707


{'val_loss': 1.0286606836196703,
 'val_acc1': 86.37435897435897,
 'val_f1': 0.870705246925354}