In [17]:
import multiprocessing as mp
from pathlib import Path

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader

In [27]:
BATCH_SIZE = 32
IMG_SIZE = 224
EPOCHS = 30
LR = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "mps"  # MPS for Mac M4
NUM_WORKERS = 4
N_CLASSES = 2  # healthy vs unhealthy
N_DOMAINS = 3  # shelled, unshelled, mixed

In [19]:
class TamarindDataset(Dataset):
    """Image dataset backed by a CSV split file.

    Each CSV row must provide:
      - ``path``:   path to an image file readable by PIL,
      - ``class``:  ``"healthy"`` or ``"unhealthy"`` (see ``cls2idx``),
      - ``domain``: ``"shelled"``, ``"unshelled"`` or ``"mixed"`` (see ``dom2idx``).

    ``__getitem__`` returns ``(image, label_idx, domain_idx)``, where ``image`` is the
    RGB PIL image passed through ``transform`` when one is given.
    """

    # Columns every split CSV must contain.
    REQUIRED_COLUMNS = ("path", "class", "domain")

    def __init__(self, csv_file, transform=None):
        """
        Args:
            csv_file: path (or buffer) accepted by ``pd.read_csv``.
            transform: optional callable applied to the PIL image.

        Raises:
            ValueError: if a required column is missing or a ``class`` value is not a
                key of ``cls2idx``. Checked up front so a bad split fails immediately
                rather than mid-epoch inside a DataLoader worker.
        """
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.cls2idx = {"healthy": 0, "unhealthy": 1}
        self.dom2idx = {"shelled": 0, "unshelled": 1, "mixed": 2}

        missing = [col for col in self.REQUIRED_COLUMNS if col not in self.df.columns]
        if missing:
            raise ValueError(f"{csv_file}: missing required column(s) {missing}")
        unknown = set(self.df["class"]) - set(self.cls2idx)
        if unknown:
            raise ValueError(
                f"{csv_file}: unknown class label(s) {sorted(map(str, unknown))}"
            )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # The context manager closes the underlying file handle deterministically
        # instead of relying on PIL / garbage collection.
        with Image.open(row["path"]) as im:
            img = im.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = self.cls2idx[row["class"]]
        # NOTE(review): unknown or missing (NaN) domains silently map to 0 ("shelled");
        # kept for compatibility — consider validating this column like ``class``.
        domain = self.dom2idx.get(row["domain"], 0)
        return img, label, domain

In [20]:
# ImageNet channel statistics, presumably matching the backbone's pretraining
# normalisation — TODO confirm against the timm model's pretrained_cfg.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: random-resized crop, horizontal flip and mild colour jitter
# for augmentation, then tensor conversion and normalisation.
train_tfms = T.Compose(
    [
        T.RandomResizedCrop(IMG_SIZE),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

# Validation pipeline: deterministic resize to IMG_SIZE x IMG_SIZE, no augmentation.
val_tfms = T.Compose(
    [
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

In [21]:
class FiLMAdapter(nn.Module):
    """Domain-conditioned FiLM layer: ``feats * (1 + gamma[d]) + beta[d]``.

    Learns one (gamma, beta) pair of length ``feat_dim`` per domain. Both tables are
    zero-initialised so the adapter starts as the identity and does not perturb the
    pretrained backbone features at step 0 (``nn.Embedding`` otherwise draws its
    weights from N(0, 1)).
    """

    def __init__(self, feat_dim, n_domains):
        super().__init__()
        self.gamma = nn.Embedding(n_domains, feat_dim)
        self.beta = nn.Embedding(n_domains, feat_dim)
        nn.init.zeros_(self.gamma.weight)
        nn.init.zeros_(self.beta.weight)

    def forward(self, feats, domain_idx):
        """
        Args:
            feats: ``[B, feat_dim]`` pooled features, or ``[B, ..., feat_dim]`` token
                features (feature dimension last).
            domain_idx: ``[B]`` integer domain indices.

        Returns:
            Tensor with the same shape as ``feats``.
        """
        gamma = self.gamma(domain_idx)  # [B, feat_dim]
        beta = self.beta(domain_idx)    # [B, feat_dim]
        # Insert singleton axes between batch and feature dims so the per-sample
        # parameters broadcast over any token axes. An unconditional unsqueeze(1)
        # would turn [B, D] inputs into a [B, B, D] outer product.
        shape = (gamma.shape[0],) + (1,) * (feats.dim() - 2) + (gamma.shape[-1],)
        return feats * (1 + gamma.view(shape)) + beta.view(shape)

In [22]:
class TeacherModel(nn.Module):
    """timm backbone + per-domain FiLM adapter + linear classification head.

    ``forward(x, domain_idx)`` returns ``(logits, feats)``, where ``feats`` are the
    FiLM-modulated backbone features that the classifier consumes.

    Args:
        n_classes: number of output classes.
        n_domains: number of domains the FiLM adapter conditions on.
        backbone_name: timm model name (default keeps the original Swin-T).
        pretrained: load pretrained backbone weights; set False to build offline.
    """

    def __init__(self, n_classes=2, n_domains=3,
                 backbone_name="swin_tiny_patch4_window7_224", pretrained=True):
        super().__init__()
        # num_classes=0 removes timm's classifier head so the backbone returns features.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features
        self.film = FiLMAdapter(feat_dim, n_domains)
        self.classifier = nn.Linear(feat_dim, n_classes)

    def forward(self, x, domain_idx):
        feats = self.backbone(x)  # [B, feat_dim]
        feats = self.film(feats, domain_idx)
        logits = self.classifier(feats)
        return logits, feats

In [23]:
def contrastive_loss(feats, labels, temp=0.1):
    """Supervised contrastive (SupCon-style) loss over a batch.

    For each anchor i, every *other* sample with the same label is a positive, and
    all samples except i itself form the softmax denominator:

        L_i = -mean_{p in P(i)} log( exp(s_ip / temp) / sum_{a != i} exp(s_ia / temp) )

    where s is cosine similarity. The loss is averaged over anchors that have at
    least one positive. If no anchor has one (all labels distinct, or a batch of
    size 1), a zero scalar is returned instead of NaN.

    Args:
        feats: ``[B, D]`` embeddings (L2-normalised here).
        labels: ``[B]`` integer class labels.
        temp: softmax temperature.

    Returns:
        Scalar loss tensor.
    """
    feats = F.normalize(feats, dim=1)
    n = feats.size(0)
    self_mask = torch.eye(n, dtype=torch.bool, device=feats.device)
    # Positives: same label, excluding the anchor itself. Self-similarity is trivially
    # maximal and must not act as a target or dominate the denominator.
    pos_mask = labels.unsqueeze(1).eq(labels.unsqueeze(0)) & ~self_mask
    pos_counts = pos_mask.sum(dim=1)
    has_pos = pos_counts > 0
    if not has_pos.any():
        return feats.new_zeros(())
    logits = (feats @ feats.T / temp).masked_fill(self_mask, float("-inf"))
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    # masked_fill (not multiplication) so the -inf diagonal never meets a 0 weight.
    pos_log_prob = log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
    return -(pos_log_prob[has_pos] / pos_counts[has_pos]).mean()

In [24]:
def train_epoch(model, loader, opt, ce_loss, device, con_weight=0.1):
    """Run one training epoch and return the mean per-batch loss.

    Per batch: ``loss = ce_loss(logits, labels) + con_weight * contrastive_loss(feats, labels)``.

    Args:
        model: module whose ``forward(imgs, domains)`` returns ``(logits, feats)``.
        loader: iterable of ``(imgs, labels, domains)`` batches.
        opt: optimizer over ``model``'s parameters.
        ce_loss: classification criterion applied to ``(logits, labels)``.
        device: device the batch tensors are moved to.
        con_weight: weight of the contrastive term (default matches the original 0.1).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for imgs, labels, domains in loader:
        imgs, labels, domains = imgs.to(device), labels.to(device), domains.to(device)
        opt.zero_grad(set_to_none=True)
        logits, feats = model(imgs, domains)
        loss_ce = ce_loss(logits, labels)
        loss_con = contrastive_loss(feats, labels)
        loss = loss_ce + con_weight * loss_con
        loss.backward()
        opt.step()
        total_loss += loss.item()
        n_batches += 1
    # Count batches rather than len(loader): works for loaders without __len__ and
    # avoids a bare ZeroDivisionError on an empty loader.
    if n_batches == 0:
        raise ValueError("train_epoch: loader yielded no batches")
    return total_loss / n_batches

In [25]:
@torch.no_grad()
def evaluate(model, loader, device, target_names=("healthy", "unhealthy")):
    """Print a per-class classification report and return accuracy on ``loader``.

    Args:
        model: module whose ``forward(imgs, domains)`` returns ``(logits, feats)``.
        loader: iterable of ``(imgs, labels, domains)`` batches.
        device: device the model lives on.
        target_names: class names, indexed by label id.

    Returns:
        Fraction of correctly classified samples (numpy float).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    y_true, y_pred = [], []
    for imgs, labels, domains in loader:
        logits, _ = model(imgs.to(device), domains.to(device))
        y_pred.extend(logits.argmax(1).cpu().tolist())
        # Labels are only compared on the host, so they never need to visit the device.
        y_true.extend(labels.tolist())
    if not y_true:
        raise ValueError("evaluate: loader yielded no samples")
    # Pin the label set explicitly: without `labels=`, a split or epoch where only one
    # class is observed makes classification_report raise, because the number of
    # observed classes no longer matches len(target_names).
    print(classification_report(
        y_true,
        y_pred,
        labels=list(range(len(target_names))),
        target_names=list(target_names),
        zero_division=0,
    ))
    return np.mean(np.array(y_true) == np.array(y_pred))

In [26]:
DATA_DIR = Path("./output")

train_ds = TamarindDataset(DATA_DIR / "split_train.csv", transform=train_tfms)
val_ds = TamarindDataset(DATA_DIR / "split_val.csv", transform=val_tfms)

# Shared loader settings:
# - persistent_workers keeps worker processes alive across epochs instead of
#   re-creating them (and re-pickling the dataset) every epoch. It is only valid
#   when num_workers > 0.
# - pinned host memory speeds up host->CUDA copies; not used for MPS/CPU.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0,
    pin_memory=(DEVICE == "cuda"),
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_kwargs)

In [28]:
# Build the teacher, move it to the selected device, and set up AdamW plus the
# cross-entropy criterion used for the classification term.
model = TeacherModel(n_classes=N_CLASSES, n_domains=N_DOMAINS)
model = model.to(DEVICE)
opt = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=1e-4)
ce_loss = nn.CrossEntropyLoss()

In [29]:
CKPT_PATH = "teacher_best.pth"

# Start below any achievable accuracy so the first epoch always writes a checkpoint.
# Starting at 0 would never save if validation accuracy stayed at 0.
best_acc = float("-inf")
for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader, opt, ce_loss, DEVICE)
    acc = evaluate(model, val_loader, DEVICE)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss {loss:.4f}, Val Acc {acc:.3f}")
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), CKPT_PATH)
print("Training done. Best val acc:", best_acc)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
    from multiprocessing.spawn import spawn_main; spawn_main(tracker_fd=84, pipe_handle=101)
                                                  ~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'TamarindDataset' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>


KeyboardInterrupt: 