# TEST MODE

In [1]:
TEST_MODE=False

# Seeds

In [2]:
import numpy as np
import random
import torch
import os

seed=247

os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


In [3]:
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Imports

In [4]:
import pandas as pd
from torch import nn
from sklearn.model_selection import train_test_split
import timm
import PIL 
from PIL import Image
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
import tqdm 
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, accuracy_score



# Loading data

In [5]:
train=pd.read_csv('/kaggle/input/plant-pathology-2020-fgvc7/train.csv')
test=pd.read_csv('/kaggle/input/plant-pathology-2020-fgvc7/test.csv')
sample=pd.read_csv('/kaggle/input/plant-pathology-2020-fgvc7/sample_submission.csv')

img_dir='/kaggle/input/plant-pathology-2020-fgvc7/images'

In [6]:
train_data, eval_data=train_test_split(train, test_size=0.15, random_state=seed)

# EDA

In [7]:
train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1821 entries, 0 to 1820
Data columns (total 5 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   image_id           1821 non-null   object
 1   healthy            1821 non-null   int64 
 2   multiple_diseases  1821 non-null   int64 
 3   rust               1821 non-null   int64 
 4   scab               1821 non-null   int64 
dtypes: int64(4), object(1)
memory usage: 71.3+ KB


# Dataset

## Dataset class

In [8]:
LABEL_COLS = ["healthy", "multiple_diseases", "rust", "scab"]
NUM_CLASSES = len(LABEL_COLS)

# ===== Dataset (multiclass) =====
class PlantDataset(Dataset):
    def __init__(self, df, img_dir, transforms=None, is_train=True):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transforms = transforms
        self.is_train = is_train

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row["image_id"]
        img_path = os.path.join(self.img_dir, f"{img_id}.jpg")
        image = Image.open(img_path).convert("RGB")

        if self.transforms is not None:
            image = self.transforms(image)

        if self.is_train:
            # ensure numeric -> to_numpy with dtype to avoid object array
            one_hot = row[LABEL_COLS].to_numpy(dtype=np.float32, copy=False)
            target = int(one_hot.argmax())   # 0..3
            return {"image": image, "labels": torch.tensor(target, dtype=torch.long), "image_id": img_id}
        else:
            return {"image": image, "image_id": img_id}

        

## Augmentations

In [9]:
IMG_SIZE=224

In [10]:
train_transforms=v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.RandomVerticalFlip(p=0.5),
    v2.RandomHorizontalFlip(p=0.5),
    v2.RandomRotation(degrees=20),
    
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [11]:
eval_transforms=v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [12]:
test_transforms=v2.Compose([
    v2.Resize((IMG_SIZE, IMG_SIZE)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Dataset creation

In [13]:
train_dataset=PlantDataset(train_data, img_dir, train_transforms, is_train=True)
eval_dataset=PlantDataset(eval_data, img_dir, eval_transforms, is_train=True)
test_dataset=PlantDataset(test, img_dir, test_transforms, is_train=False)

## Dataloaders

### BATCH SIZE

In [14]:
BATCH_SIZE=32

In [15]:
train_dataloader=DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
eval_dataloader=DataLoader(eval_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_dataloader=DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [16]:
LABEL_COLS = ["healthy", "multiple_diseases", "rust", "scab"]
NUM_CLASSES = len(LABEL_COLS)

# 1) Class counts from your train split
counts = train_data[LABEL_COLS].sum().to_numpy(dtype=np.float64)  # shape (4,)
counts = np.clip(counts, 1, None)  # safety in case of zeros

# 2) Inverse-frequency weights, normalized to mean = 1 (stable for LR)
inv_freq = 1.0 / counts
weights = inv_freq / inv_freq.mean()  # e.g., array([w0, w1, w2, w3])

class_weights = torch.tensor(weights, dtype=torch.float32)

In [17]:
from torch.utils.data import WeightedRandomSampler

# class index (0..3) for each training row
y_idx = train_data[LABEL_COLS].to_numpy(dtype=np.float32).argmax(1)
y_idx_t = torch.tensor(y_idx, dtype=torch.long)

# per-sample weight = class_weight of its class
sample_weights = class_weights.cpu()[y_idx_t].numpy()

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # one "epoch" ≈ same #samples
    replacement=True
)

train_dataloader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE,
    sampler=sampler,  # <-- use sampler, not shuffle
    num_workers=4, pin_memory=True, persistent_workers=True
)

# Vision model

## EPOCHS

In [18]:
EPOCHS=15

## MODEL

In [19]:
MODEL_NAME='swin_large_patch4_window7_224'

In [20]:
model=timm.create_model(MODEL_NAME, pretrained=True, num_classes=4).to(device)

model.safetensors:   0%|          | 0.00/788M [00:00<?, ?B/s]

## Criterion

In [21]:
#criterion = nn.CrossEntropyLoss(label_smoothing=0.05)
LABEL_COLS = ["healthy", "multiple_diseases", "rust", "scab"]
NUM_CLASSES = len(LABEL_COLS)

# 1) Class counts from your train split
counts = train_data[LABEL_COLS].sum().to_numpy(dtype=np.float64)  # shape (4,)
counts = np.clip(counts, 1, None)  # safety in case of zeros

# 2) Inverse-frequency weights, normalized to mean = 1 (stable for LR)
inv_freq = 1.0 / counts
weights = inv_freq / inv_freq.mean()  # e.g., array([w0, w1, w2, w3])

class_weights = torch.tensor(weights, dtype=torch.float32)

# 3) Loss with both weight and label smoothing
criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    label_smoothing=0.11
).to(device)

## Optimizer

In [22]:
#base_lr=0.01*(BATCH_SIZE/256)
BASE_LR_REF = 3e-4
base_lr = BASE_LR_REF * (BATCH_SIZE / 256)

In [23]:
#optimizer=torch.optim.SGD(model.parameters(), lr=base_lr)
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=0.05)

## Scheduler

In [24]:
NUM_STEPS=len(train_dataloader)*EPOCHS

In [25]:
WARMUP_STEPS=NUM_STEPS*0.15

In [26]:
scheduler1=torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1e-3, end_factor=1.00, total_iters=WARMUP_STEPS)
scheduler2=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,  T_max=NUM_STEPS - WARMUP_STEPS, eta_min=base_lr*0.07)

scheduler=torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[scheduler1, scheduler2],
    milestones=[WARMUP_STEPS]
    
)

## Training loop

In [27]:

for epoch in range(EPOCHS):
    model.train()
    train_running_loss = 0.0
    all_probs, all_targets = [], []
    train_pbar = tqdm(train_dataloader, desc=f"Training {epoch+1}/{EPOCHS}: ", leave=False)

    for step, batch in enumerate(train_pbar):
        optimizer.zero_grad(set_to_none=True)
        X = batch['image'].to(device, non_blocking=True)
        y = batch['labels'].to(device)              # <- 'labels', not 'label'

        logits = model(X)                           # [B, 4]
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_running_loss += loss.detach().item() * X.size(0)

        probs = torch.nn.functional.softmax(logits, dim=1).detach().cpu().numpy()  # [B,4]
        all_probs.append(probs)
        all_targets.append(y.detach().cpu().numpy())

    all_probs = np.concatenate(all_probs) if all_probs else np.zeros((0, NUM_CLASSES))
    all_targets_idx = np.concatenate(all_targets) if all_targets else np.zeros((0,), dtype=int)
    # For multiclass ROC-AUC we need one-hot targets:
    all_targets_1h = np.eye(NUM_CLASSES, dtype=np.float32)[all_targets_idx]
    auc_macro = roc_auc_score(all_targets_1h, all_probs, average='macro', multi_class='ovr') if len(all_targets_idx) else float('nan')
    train_loss = train_running_loss / len(train_dataset)
    print(f"train: loss={train_loss:.4f}, AUC_macro={auc_macro:.4f}, acc={accuracy_score(all_targets_idx, all_probs.argmax(1)):.4f}")

    # ===== Eval =====
    model.eval()
    eval_running_loss = 0.0
    all_probs, all_targets = [], []
    eval_pbar = tqdm(eval_dataloader, desc=f"Evaluating {epoch+1}/{EPOCHS}: ", leave=False)

    with torch.no_grad():
        for batch in eval_pbar:
            X = batch['image'].to(device, non_blocking=True)
            y = batch['labels'].to(device)

            logits = model(X)
            loss = criterion(logits, y)
            eval_running_loss += loss.detach().item() * X.size(0)

            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
            all_probs.append(probs)
            all_targets.append(y.cpu().numpy())

    all_probs = np.concatenate(all_probs) if all_probs else np.zeros((0, NUM_CLASSES))
    all_targets_idx = np.concatenate(all_targets) if all_targets else np.zeros((0,), dtype=int)
    all_targets_1h = np.eye(NUM_CLASSES, dtype=np.float32)[all_targets_idx]
    auc_macro = roc_auc_score(all_targets_1h, all_probs, average='macro', multi_class='ovr') if len(all_targets_idx) else float('nan')
    eval_loss = eval_running_loss / len(eval_dataset)
    print(f"eval:  loss={eval_loss:.4f}, AUC_macro={auc_macro:.4f}, acc={accuracy_score(all_targets_idx, all_probs.argmax(1)):.4f}")

                                                              

train: loss=1.0205, AUC_macro=0.6530, acc=0.2534


                                                              

eval:  loss=1.2249, AUC_macro=0.8979, acc=0.1898


                                                              

train: loss=0.5410, AUC_macro=0.9646, acc=0.6484


                                                              

eval:  loss=0.7768, AUC_macro=0.9791, acc=0.7993


                                                              

train: loss=0.3791, AUC_macro=0.9923, acc=0.9017


                                                              

eval:  loss=0.7019, AUC_macro=0.9890, acc=0.8869


                                                              

train: loss=0.3711, AUC_macro=0.9950, acc=0.9373


                                                              

eval:  loss=0.7059, AUC_macro=0.9924, acc=0.9453


                                                              

train: loss=0.3268, AUC_macro=0.9986, acc=0.9625


                                                              

eval:  loss=0.6824, AUC_macro=0.9959, acc=0.9562


                                                              

train: loss=0.3318, AUC_macro=0.9984, acc=0.9599


                                                              

eval:  loss=0.6760, AUC_macro=0.9929, acc=0.9453


                                                              

train: loss=0.3118, AUC_macro=0.9991, acc=0.9832


                                                              

eval:  loss=0.6935, AUC_macro=0.9929, acc=0.9672


                                                              

train: loss=0.3302, AUC_macro=0.9995, acc=0.9890


                                                              

eval:  loss=0.6918, AUC_macro=0.9947, acc=0.9781


                                                              

train: loss=0.3366, AUC_macro=0.9998, acc=0.9929


                                                              

eval:  loss=0.6775, AUC_macro=0.9953, acc=0.9818


                                                               

train: loss=0.3233, AUC_macro=1.0000, acc=0.9903


                                                               

eval:  loss=0.6986, AUC_macro=0.9938, acc=0.9818


                                                               

train: loss=0.3214, AUC_macro=1.0000, acc=0.9942


                                                               

eval:  loss=0.6734, AUC_macro=0.9942, acc=0.9818


                                                               

train: loss=0.2986, AUC_macro=0.9995, acc=0.9948


                                                               

eval:  loss=0.6849, AUC_macro=0.9943, acc=0.9818


                                                               

train: loss=0.3081, AUC_macro=1.0000, acc=0.9968


                                                               

eval:  loss=0.6857, AUC_macro=0.9942, acc=0.9818


                                                               

train: loss=0.3083, AUC_macro=1.0000, acc=0.9961


                                                               

eval:  loss=0.6849, AUC_macro=0.9944, acc=0.9818


                                                               

train: loss=0.3037, AUC_macro=1.0000, acc=0.9968


                                                               

eval:  loss=0.6814, AUC_macro=0.9944, acc=0.9818




# Submission

## no TTA 

In [28]:
model.eval()
test_probs = []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="test", leave=True):
        x = batch["image"].to(device, non_blocking=True)
        logits = model(x)
        probs  = torch.nn.functional.softmax(logits, dim=1).float().cpu().numpy()   # [B,4]
        test_probs.append(probs)

test_probs = np.concatenate(test_probs, axis=0)

sub = pd.DataFrame({
    'image_id': sample['image_id'],
    'healthy':            test_probs[:, 0],
    'multiple_diseases':  test_probs[:, 1],
    'rust':               test_probs[:, 2],
    'scab':               test_probs[:, 3],
})
sub.to_csv('submission.csv', index=False)


test: 100%|██████████| 57/57 [00:28<00:00,  1.98it/s]


## TTA 

In [29]:
def tta_views(x):
    return [
        x,
        torch.flip(x, dims=(-1,)),
        torch.flip(x, dims=(-2,)),
        torch.rot90(x, k=1, dims=(-2, -1)),
        torch.rot90(x, k=2, dims=(-2, -1)),
        torch.rot90(x, k=3, dims=(-2, -1)),
    ]

model.eval()
proba_batches = []
with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="test TTA", leave=False):
        x = batch["image"].to(device, non_blocking=True)
        acc = None
        views = tta_views(x)
        for v in views:
            z = model(v)                              # [B,4]
            acc = z if acc is None else acc + z
        logits = acc / len(views)                     # average logits
        probs  = torch.nn.functional.softmax(logits, dim=1).float().cpu().numpy()
        proba_batches.append(probs)

proba = np.concatenate(proba_batches, axis=0)

sub_tta = pd.DataFrame({
    'image_id': sample['image_id'],
    'healthy':            proba[:, 0],
    'multiple_diseases':  proba[:, 1],
    'rust':               proba[:, 2],
    'scab':               proba[:, 3],
})
sub_tta.to_csv('submission_tta.csv', index=False)


                                                         