In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import utils
from trainer import Trainer

In [None]:
# Settings for the training session (everything a reader might tune lives here).

MODEL_NAME = "resnet50"   # backbone architecture passed to utils.init_pretrained_model
SSL_METHOD = "DINO"       # self-supervised (SSL) pretraining method whose checkpoint is loaded
GPU_NUM = 0               # CUDA device index, used only when a GPU is available
SEED = 42                 # global torch RNG seed for reproducible runs

num_labels = 3            # number of disease classes predicted by the classifier
epochs = 50
batch_size = 32
lr = 1e-4

# Seed before any model or data objects are created in later cells, so torch-based
# randomness (augmentation crops/flips, loader shuffling, weight init) is repeatable
# under Restart & Run All. torch.manual_seed also seeds all CUDA devices.
torch.manual_seed(SEED)

device = torch.device(f'cuda:{GPU_NUM}' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the SSL-pretrained backbone selected by MODEL_NAME / SSL_METHOD together with
# a classifier head sized for `num_labels` plant-disease classes, placed on `device`.
model, classifier = utils.init_pretrained_model(
    MODEL_NAME,
    SSL_METHOD,
    num_labels=num_labels,
    device=device,
)

[ok] Loaded cleanly: LeafVision_DINO_resnet50.pth


In [None]:
# Plant-disease datasets/dataloaders, optimizer, and learning-rate scheduler.

train_dataset_path = "./dataset/Tomato/05images/train"
valid_dataset_path = "./dataset/Tomato/05images/valid"
# NOTE(review): the test split lives outside `05images/`, presumably a shared
# held-out set; confirm this is intended.
test_dataset_path = "./dataset/Tomato/test"

# Per-channel RGB mean/std, shared by every split so train and eval inputs are
# normalized identically.
NORM_MEAN = (0.4371, 0.5177, 0.3476)
NORM_STD = (0.1789, 0.1545, 0.1923)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])
test_transform = transforms.Compose([
    # Use the enum: integer interpolation codes (3 == bicubic) are deprecated in torchvision.
    transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

train_set = datasets.ImageFolder(train_dataset_path, transform=train_transform)
valid_set = datasets.ImageFolder(valid_dataset_path, transform=test_transform)
test_set = datasets.ImageFolder(test_dataset_path, transform=test_transform)

# ImageFolder derives label indices from the sorted sub-folder names of each split.
# Mismatched folders would silently assign different labels to the same disease.
assert train_set.classes == valid_set.classes == test_set.classes, "class folders differ between splits"
assert len(train_set.classes) == num_labels, f"expected {num_labels} classes, found {train_set.classes}"

# Shuffle only the training data. ImageFolder lists samples grouped by class, so
# unshuffled batches would each contain (mostly) a single class.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Backbone and classifier share the base learning rate. They are kept as separate
# param groups so their rates can be tuned independently later.
optimizer = torch.optim.Adam(
    params=[
        {'params': model.parameters()},
        {'params': classifier.parameters()},
    ],
    lr=lr,
)

# Cosine decay from `lr` to 0 over `epochs` scheduler steps (assumes the Trainer
# calls scheduler.step() once per epoch; TODO confirm).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

In [None]:
# Wire the network, the three data splits, and the optimization setup into a Trainer.
# TODO (translated from the original Korean note): keep train_loader, valid_loader and
# test_loader distinct. They are already passed separately below, so remove this
# note once confirmed resolved.
trainer = Trainer(
    # network
    model=model,
    classifier=classifier,
    arch=MODEL_NAME,
    # data
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    batch_size=batch_size,
    # optimization
    epochs=epochs,
    optimizer=optimizer,
    scheduler=scheduler,
)