# Advanced Biomass Prediction: TTA Inference
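Runs inference with the trained Student models (one checkpoint per fold) and averages their predictions over test-time augmentation (TTA): each image is evaluated at four rotations (0°, 90°, 180° and 270°).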

In [None]:
# torch is imported here because this cell runs before the cell that holds
# the notebook's main import block; without it `torch.cuda` is undefined.
import torch

# Settings shared by the inference cells.
CONFIG = {
    # timm model name used to build the backbone.
    'backbone': 'convnextv2_tiny.fcmae_ft_in22k_in1k',
    'img_size': 512,
    'batch_size': 4,
    # One checkpoint per fold; files that do not exist are skipped at run time.
    'weights': ['student_fold0.pth', 'student_fold1.pth', 'student_fold2.pth'],
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    # Forward slash resolves on both Windows and POSIX; a backslash would be
    # treated as part of the directory name outside Windows.
    'data_dir': '../csiro-biomass',
    # Model outputs are multiplied by this factor before being written out.
    'target_scale': 100.0,
}

In [None]:

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

class DepthwiseSeparableConv(nn.Module):
    """3x3 depthwise conv, then 1x1 pointwise conv, BatchNorm and ReLU.

    Spatial size is preserved (``padding=1`` on the 3x3 kernel); only the
    channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # groups == in_channels: every input channel gets its own 3x3 filter.
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        # 1x1 conv mixes information across channels.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply depthwise -> pointwise -> BatchNorm -> ReLU to ``x``."""
        per_channel = self.depthwise(x)
        mixed = self.pointwise(per_channel)
        normalized = self.bn(mixed)
        return self.act(normalized)

class BiFPNLayer(nn.Module):
    """Fuse three same-channel feature maps with three separable-conv nodes.

    ``p5`` is merged into ``p4``, that result into ``p3``, and the refined
    ``p3`` is merged back into the ``p4`` result. ``p5`` itself is returned
    unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        fusion_convs = [DepthwiseSeparableConv(channels, channels) for _ in range(3)]
        self.nodes = nn.ModuleList(fusion_convs)

    def forward(self, p3, p4, p5):
        """Return ``(p3_out, p4_out, p5)``.

        Maps are resized with ``F.interpolate`` (default mode) to the spatial
        size of the map they are added to.
        """
        p4_td_conv, p3_conv, p4_out_conv = self.nodes

        # p5 -> p4
        p5_resized = F.interpolate(p5, size=p4.shape[-2:])
        p4_td = p4_td_conv(p4 + p5_resized)

        # p4 -> p3
        p4_td_resized = F.interpolate(p4_td, size=p3.shape[-2:])
        p3_out = p3_conv(p3 + p4_td_resized)

        # p3 -> p4
        p3_resized = F.interpolate(p3_out, size=p4_td.shape[-2:])
        p4_out = p4_out_conv(p4_td + p3_resized)

        return p3_out, p4_out, p5

class BiomassModel(nn.Module):
    """timm feature backbone + BiFPN fusion with biomass and auxiliary heads.

    Args:
        model_name: timm model name passed to ``timm.create_model``.
        feature_dim: Channel width that the last three backbone stages are
            projected to before fusion, and the input width of every head.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``
            as before; pass ``False`` when a checkpoint is loaded with
            ``load_state_dict`` right after construction, so no pretrained
            weights have to be fetched first.
    """

    def __init__(self, model_name, feature_dim=256, pretrained=True):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True)
        feature_info = self.backbone.feature_info.get_dicts()
        # 1x1 convs bring the last three stages to a common channel width.
        self.align_p3 = nn.Conv2d(feature_info[-3]['num_chs'], feature_dim, 1)
        self.align_p4 = nn.Conv2d(feature_info[-2]['num_chs'], feature_dim, 1)
        self.align_p5 = nn.Conv2d(feature_info[-1]['num_chs'], feature_dim, 1)
        self.bifpn = BiFPNLayer(feature_dim)
        self.total_head = nn.Sequential(nn.Linear(feature_dim, 1), nn.Softplus())
        self.comp_head = nn.Sequential(nn.Linear(feature_dim, 4), nn.Softplus())
        self.aux_height = nn.Sequential(nn.Linear(feature_dim, 1))
        self.aux_ndvi = nn.Sequential(nn.Linear(feature_dim, 1), nn.Sigmoid())

    def forward(self, x):
        """Predict from a batch of tile stacks.

        Args:
            x: 5-D tensor unpacked as ``(b, t, c, h, w)``; the ``t`` tiles of
                each sample are encoded independently and their feature maps
                are averaged.

        Returns:
            Tuple ``(total, components, height, ndvi, f4)``: the outputs of
            ``total_head`` and ``comp_head`` on the pooled p4 features, of
            ``aux_height`` and ``aux_ndvi`` on the pooled p5 features, and
            the pooled p4 feature vector itself.
        """
        b, t, c, h, w = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x.reshape(b * t, c, h, w)
        feats = self.backbone(x)
        p3 = self.align_p3(feats[-3])
        p4 = self.align_p4(feats[-2])
        p5 = self.align_p5(feats[-1])
        # p3 still takes part in the fusion, but no head reads it afterwards.
        _, p4, p5 = self.bifpn(p3, p4, p5)

        def pool(f):
            # Average over the tile axis, then global-average-pool spatially.
            f = f.reshape(b, t, *f.shape[1:]).mean(dim=1)
            return F.adaptive_avg_pool2d(f, 1).flatten(1)

        f4, f5 = pool(p4), pool(p5)
        return self.total_head(f4), self.comp_head(f4), self.aux_height(f5), self.aux_ndvi(f5), f4


In [None]:
class InferenceDataset(Dataset):
    """Dataset yielding four corner tiles per image for inference.

    Args:
        df: DataFrame with an ``image_path`` column (path relative to
            ``img_dir``).
        img_dir: Directory that ``image_path`` values are joined onto.
        rotate: Number of 90-degree rotations (``np.rot90`` ``k``) applied to
            the image before tiling; values <= 0 leave the image as is.
        tile_size: Side length of the square corner crops and of the resized
            tiles. Defaults to 512, the previously hard-coded value.
    """

    def __init__(self, df, img_dir, rotate=0, tile_size=512):
        self.df = df
        self.img_dir = img_dir
        self.rotate = rotate
        self.tile_size = tile_size
        self.tf = A.Compose([A.Resize(tile_size, tile_size), A.Normalize(), ToTensorV2()])

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _corner_tiles(img, tile_size):
        """Return the top-left, top-right, bottom-left, bottom-right crops."""
        h, w = img.shape[:2]
        # Clamp at 0: for images smaller than the tile a negative start would
        # be read as an offset from the end and select the wrong region.
        top = max(h - tile_size, 0)
        left = max(w - tile_size, 0)
        return [
            img[:tile_size, :tile_size],
            img[:tile_size, left:w],
            img[top:h, :tile_size],
            img[top:h, left:w],
        ]

    def __getitem__(self, idx):
        """Return ``(tiles, image_path)`` where ``tiles`` stacks the four transformed crops."""
        row = self.df.iloc[idx]
        # Context manager closes the file handle as soon as pixels are read.
        with Image.open(os.path.join(self.img_dir, row['image_path'])) as handle:
            img = np.array(handle.convert('RGB'))
        if self.rotate > 0:
            # rot90 returns a negatively-strided view; materialize it so the
            # transforms receive an ordinary contiguous array.
            img = np.ascontiguousarray(np.rot90(img, k=self.rotate))
        tiles = self._corner_tiles(img, self.tile_size)
        return torch.stack([self.tf(image=t)['image'] for t in tiles]), row['image_path']

In [None]:
def run_inference():
    """Run fold-ensemble inference with rotation TTA and write ``submission.csv``.

    Reads ``test.csv`` from ``CONFIG['data_dir']``, predicts every unique
    ``image_path`` with each checkpoint in ``CONFIG['weights']`` that exists
    on disk, at four rotations, and averages the scaled predictions. One row
    per (image, target name) is written with columns ``sample_id`` and
    ``target``.

    Raises:
        FileNotFoundError: If none of the configured weight files exist, as
            there would be nothing to average.
    """
    test_df = pd.read_csv(os.path.join(CONFIG['data_dir'], 'test.csv'))
    unique_images = test_df[['image_path']].drop_duplicates()
    # NOTE(review): assumes the model's [total, 4 components] output order
    # matches this list — confirm against the training code.
    target_names = ['Dry_Total_g', 'GDM_g', 'Dry_Green_g', 'Dry_Dead_g', 'Dry_Clover_g']
    rotations = (0, 1, 2, 3)
    device = CONFIG['device']
    autocast_device = 'cuda' if 'cuda' in device else 'cpu'

    pred_sums = {}
    models_used = 0
    for w_path in CONFIG['weights']:
        if not os.path.exists(w_path):
            print(f'Skipping missing weights file {w_path}')
            continue
        print(f'Running inference for {w_path}...')
        model = BiomassModel(CONFIG['backbone']).to(device)
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(w_path, map_location=device))
        model.eval()
        models_used += 1

        with torch.no_grad():
            for rot in rotations:
                ds = InferenceDataset(unique_images, CONFIG['data_dir'], rotate=rot)
                loader = DataLoader(ds, batch_size=CONFIG['batch_size'])
                for x, paths in loader:
                    with torch.amp.autocast(autocast_device):
                        tp, cp, _, _, _ = model(x.to(device))
                    # Cast to float32 first: autocast may return a
                    # reduced-precision dtype (e.g. bfloat16 on CPU) that
                    # NumPy cannot represent.
                    scaled = torch.cat([tp, cp], dim=1).float() * CONFIG['target_scale']
                    batch_preds = scaled.cpu().numpy()
                    for path, pred in zip(paths, batch_preds):
                        if path in pred_sums:
                            pred_sums[path] += pred
                        else:
                            # Copy so the accumulator does not alias the batch array.
                            pred_sums[path] = pred.copy()

    if models_used == 0:
        raise FileNotFoundError(
            f"None of the configured weight files were found: {CONFIG['weights']}"
        )

    # Divide by the passes that actually ran, not the number configured;
    # otherwise a missing fold silently shrinks every prediction.
    n_passes = models_used * len(rotations)

    sub_rows = []
    for path, pred_sum in pred_sums.items():
        mean_pred = pred_sum / n_passes
        # splitext drops only the trailing extension, whatever it is.
        img_id = os.path.splitext(os.path.basename(path))[0]
        sub_rows.extend(
            {'sample_id': f'{img_id}__{name}', 'target': value}
            for name, value in zip(target_names, mean_pred)
        )
    pd.DataFrame(sub_rows).to_csv('submission.csv', index=False)

# Execute the full inference pipeline; writes submission.csv to the working directory.
run_inference()