In [1]:
import os
import sys
import gc
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
from sklearn.model_selection import StratifiedKFold
# SWA Imports
from torch.optim.swa_utils import AveragedModel, SWALR

# ==========================================
# 1. CONFIGURATION
# ==========================================
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 320      # Winning Resolution
BATCH_SIZE = 16
EPOCHS = 25         # INCREASED: MixUp needs time!
SWA_START = 18      # Start averaging late in training
LEARNING_RATE = 2e-4
N_FOLDS = 5
IMAGE_DIR = Path("/kaggle/input/csiro-biomass")
TARGET_COLS = ['Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g', 'GDM_g', 'Dry_Total_g']

print(f"Running ResNet34 MixUp + SWA (Long Training) on {DEVICE}...")

# ==========================================
# 2. 4-CHANNEL DATASET (Winning Config)
# ==========================================
class Biomass4ChannelDataset(Dataset):
    def __init__(self, df, target_cols=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.target_cols = target_cols
        self.is_test = is_test
        self.root_dir = IMAGE_DIR
        self.mean = torch.tensor([0.485, 0.456, 0.406, 0.5]).view(4,1,1)
        self.std = torch.tensor([0.229, 0.224, 0.225, 0.5]).view(4,1,1)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        rel_path = self.df.loc[idx, "image_path"]
        img_path = self.root_dir / rel_path
        try:
            pil_img = Image.open(img_path).convert("RGB").resize((IMG_SIZE, IMG_SIZE))
            img = np.array(pil_img).astype(np.float32) / 255.0
        except:
            img = np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.float32)

        # Standard ExG
        r, g, b = img[:,:,0], img[:,:,1], img[:,:,2]
        exg = (2 * g) - r - b
        exg = (exg - exg.min()) / (exg.max() - exg.min() + 1e-6)
        
        img_4c = np.dstack((img, exg))
        image = torch.tensor(img_4c.transpose(2, 0, 1), dtype=torch.float32)
        
        if not self.is_test:
            if np.random.random() > 0.5: image = torch.flip(image, dims=[2])
            if np.random.random() > 0.5: image = torch.flip(image, dims=[1])

        image = (image - self.mean) / self.std

        if self.is_test:
            img_id = Path(rel_path).stem 
            return image, img_id
        else:
            targets = self.df.loc[idx, self.target_cols].values.astype(float)
            targets = np.log1p(targets)
            return image, torch.tensor(targets, dtype=torch.float32)

# ==========================================
# 3. MODEL (ResNet34)
# ==========================================
def get_model():
    model = models.resnet34(weights=None)
    # Search Weights
    weights_path = None
    for dirname, _, filenames in os.walk('/kaggle/input'):
        for filename in filenames:
            if 'resnet34' in filename and '.pth' in filename:
                weights_path = os.path.join(dirname, filename)
                break
        if weights_path: break
            
    if weights_path:
        try: model.load_state_dict(torch.load(weights_path, weights_only=False))
        except: pass
    else:
        # Fallback
        model = models.resnet18(weights=None)
        for dirname, _, filenames in os.walk('/kaggle/input'):
            for filename in filenames:
                if 'resnet18' in filename and '.pth' in filename:
                    weights_path = os.path.join(dirname, filename)
                    break
            if weights_path: break
        if weights_path: model.load_state_dict(torch.load(weights_path, weights_only=False))

    # Adapter
    original_conv1 = model.conv1
    model.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    with torch.no_grad():
        model.conv1.weight[:, :3, :, :] = original_conv1.weight
        model.conv1.weight[:, 3:4, :, :] = torch.mean(original_conv1.weight, dim=1, keepdim=True)

    # Deeper Head (Small Upgrade for capacity)
    model.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(model.fc.in_features, 5)
    )
    return model.to(DEVICE)

# ==========================================
# 4. WEIGHTED LOSS
# ==========================================
class WeightedHuberLoss(nn.Module):
    def __init__(self, delta=1.0):
        super().__init__()
        self.huber = nn.HuberLoss(reduction='none', delta=delta)
        self.weights = torch.tensor([0.1, 0.1, 0.1, 0.2, 0.5]).to(DEVICE)
    def forward(self, preds, targets):
        return (self.huber(preds, targets) * self.weights).mean()

# ==========================================
# 5. SWA TRAINING LOOP
# ==========================================
raw_df = pd.read_csv("/kaggle/input/csiro-biomass/train.csv")
train_pivot = raw_df.pivot(index='image_path', columns='target_name', values='target').reset_index().fillna(0.0)
train_pivot['bin'] = pd.qcut(train_pivot['Dry_Total_g'], q=10, labels=False, duplicates='drop')

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

for fold, (train_idx, val_idx) in enumerate(skf.split(train_pivot, train_pivot['bin'])):
    print(f"\n=== FOLD {fold+1}/{N_FOLDS} (MixUp + SWA) ===")
    
    train_loader = DataLoader(
        Biomass4ChannelDataset(train_pivot.iloc[train_idx], TARGET_COLS),
        batch_size=BATCH_SIZE, shuffle=True, num_workers=2
    )
    valid_loader = DataLoader(
        Biomass4ChannelDataset(train_pivot.iloc[val_idx], TARGET_COLS),
        batch_size=BATCH_SIZE, shuffle=False, num_workers=2
    )
    
    model = get_model()
    swa_model = AveragedModel(model)
    
    criterion = WeightedHuberLoss(delta=1.0)
    optimizer = Adam(model.parameters(), lr=LEARNING_RATE)
    
    # Scheduler: Cosine until SWA start, then constant SWA LR
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=SWA_START)
    swa_scheduler = SWALR(optimizer, swa_lr=1e-5)
    
    for epoch in range(EPOCHS):
        model.train()
        for x, y in train_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            optimizer.zero_grad()
            
            # --- MIXUP LOGIC ---
            if np.random.random() < 0.5:
                lam = np.random.beta(1.0, 1.0)
                index = torch.randperm(x.size(0)).to(DEVICE)
                mixed_x = lam * x + (1 - lam) * x[index]
                # Unlog -> Mix -> Relog
                y_lin_a = torch.expm1(y)
                y_lin_b = torch.expm1(y[index])
                mixed_y = torch.log1p(lam * y_lin_a + (1 - lam) * y_lin_b)
                preds = model(mixed_x)
                loss = criterion(preds, mixed_y)
            else:
                preds = model(x)
                loss = criterion(preds, y)
                
            loss.backward()
            optimizer.step()
            
        # SWA Phase
        if epoch >= SWA_START:
            swa_model.update_parameters(model)
            swa_scheduler.step()
        else:
            scheduler.step()
            
        # Quick Val Check
        if (epoch+1) % 5 == 0:
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for x, y in valid_loader:
                    x, y = x.to(DEVICE), y.to(DEVICE)
                    val_loss += criterion(model(x), y).item()
            print(f"  Epoch {epoch+1} Val Loss: {val_loss/len(valid_loader):.4f}")

    # Finalize SWA
    print("Updating SWA Batch Norm...")
    torch.optim.swa_utils.update_bn(train_loader, swa_model, device=DEVICE)
    torch.save(swa_model.state_dict(), f"model_fold{fold}.pth")
    
    del model, swa_model, optimizer, train_loader, valid_loader
    torch.cuda.empty_cache()
    gc.collect()

# ==========================================
# 6. INFERENCE
# ==========================================
print("Starting Inference...")
test_df_raw = pd.read_csv("/kaggle/input/csiro-biomass/test.csv")
test_unique = test_df_raw[['image_path']].drop_duplicates()
test_ds = Biomass4ChannelDataset(test_unique, is_test=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

ensemble_preds = {} 

for fold in range(N_FOLDS):
    # Load SWA Model
    base_model = get_model()
    model = AveragedModel(base_model)
    model.load_state_dict(torch.load(f"model_fold{fold}.pth", weights_only=True))
    model.eval()
    
    with torch.no_grad():
        for images, img_ids in test_loader:
            images = images.to(DEVICE)
            
            # TTA: 3 Views
            p1 = model(images)
            p2 = model(torch.flip(images, dims=[3]))
            p3 = model(torch.flip(images, dims=[2]))
            
            avg_log = (p1 + p2 + p3) / 3.0
            preds = np.expm1(avg_log.cpu().numpy())
            
            for i, img_id in enumerate(img_ids):
                if img_id not in ensemble_preds: ensemble_preds[img_id] = np.zeros(5)
                ensemble_preds[img_id] += preds[i]
    
    del model, base_model
    torch.cuda.empty_cache()
    gc.collect()

results = []
for img_id, total_preds in ensemble_preds.items():
    avg_pred = total_preds / N_FOLDS
    for j, col in enumerate(TARGET_COLS):
        results.append({'sample_id': f"{img_id}__{col}", 'target': float(avg_pred[j])})

submission_df = pd.DataFrame(results)
submission_df['target'] = submission_df['target'].clip(lower=0.0)
submission_df.to_csv("submission.csv", index=False)
print("SWA + MixUp + ResNet34 Submission Ready.")

Running ResNet34 MixUp + SWA (Long Training) on cuda...

=== FOLD 1/5 (MixUp + SWA) ===
  Epoch 5 Val Loss: 0.0344
  Epoch 10 Val Loss: 0.0325
  Epoch 15 Val Loss: 0.0258
  Epoch 20 Val Loss: 0.0270
  Epoch 25 Val Loss: 0.0267
Updating SWA Batch Norm...

=== FOLD 2/5 (MixUp + SWA) ===
  Epoch 5 Val Loss: 0.0275
  Epoch 10 Val Loss: 0.0242
  Epoch 15 Val Loss: 0.0165
  Epoch 20 Val Loss: 0.0168
  Epoch 25 Val Loss: 0.0184
Updating SWA Batch Norm...

=== FOLD 3/5 (MixUp + SWA) ===
  Epoch 5 Val Loss: 0.0322
  Epoch 10 Val Loss: 0.0243
  Epoch 15 Val Loss: 0.0210
  Epoch 20 Val Loss: 0.0205
  Epoch 25 Val Loss: 0.0209
Updating SWA Batch Norm...

=== FOLD 4/5 (MixUp + SWA) ===
  Epoch 5 Val Loss: 0.0322
  Epoch 10 Val Loss: 0.0265
  Epoch 15 Val Loss: 0.0191
  Epoch 20 Val Loss: 0.0216
  Epoch 25 Val Loss: 0.0209
Updating SWA Batch Norm...

=== FOLD 5/5 (MixUp + SWA) ===
  Epoch 5 Val Loss: 0.0341
  Epoch 10 Val Loss: 0.0291
  Epoch 15 Val Loss: 0.0285
  Epoch 20 Val Loss: 0.0247
  Epoch 2

In [2]:
submission_df.head()

Unnamed: 0,sample_id,target
0,ID1001187975__Dry_Green_g,24.814009
1,ID1001187975__Dry_Dead_g,28.725051
2,ID1001187975__Dry_Clover_g,0.354886
3,ID1001187975__GDM_g,22.247827
4,ID1001187975__Dry_Total_g,50.685926
