In [None]:
def build_tta_transforms(
    size=(224, 224),
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    rotation_degrees=15,
):
    """Build the deterministic test-time-augmentation (TTA) pipelines.

    Every pipeline resizes to ``size``, converts to a tensor and normalizes
    with ``mean``/``std`` (the defaults are the values the original cell
    hard-coded). The pipelines differ only in the geometric augmentation
    applied after resizing.

    Args:
        size: target (height, width) passed to ``transforms.Resize``.
        mean: per-channel mean for ``transforms.Normalize``.
        std: per-channel std for ``transforms.Normalize``.
        rotation_degrees: magnitude of the two fixed rotation views.

    Returns:
        List of ``transforms.Compose`` objects in this order: identity,
        horizontal flip, rotate +rotation_degrees, rotate -rotation_degrees.
    """
    def pipeline(*augmentations):
        # Shared resize -> (augmentations) -> tensor -> normalize chain.
        return transforms.Compose([
            transforms.Resize(size),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(list(mean), list(std)),
        ])

    # A degenerate (a, a) range makes RandomRotation rotate by exactly `a`
    # every time, so TTA outputs are reproducible. RandomRotation(15) would
    # sample a fresh angle in [-15, 15] on every call.
    return [
        pipeline(),
        pipeline(transforms.RandomHorizontalFlip(p=1.0)),  # p=1.0 -> always flip
        pipeline(transforms.RandomRotation((rotation_degrees, rotation_degrees))),
        pipeline(transforms.RandomRotation((-rotation_degrees, -rotation_degrees))),
    ]


tta_transforms = build_tta_transforms()


In [None]:
class TTASoilDataset(Dataset):
    """Test-set dataset that returns every TTA view of an image at once.

    Each item is ``(img_id, views)``. ``views`` is ``torch.stack`` of
    ``tf(image)`` for every ``tf`` in ``transforms_list``, so its first axis
    has length ``len(transforms_list)``.
    """

    def __init__(self, image_ids, image_dir, transforms_list):
        """
        Args:
            image_ids: image file names. Each is joined onto ``image_dir`` to
                locate the file and is also returned unchanged as the item's id.
            image_dir: directory containing the image files.
            transforms_list: callables mapping a PIL RGB image to a tensor.
                All must produce tensors of the same shape so they can be
                stacked.
        """
        self.image_ids = image_ids
        self.image_dir = image_dir
        self.transforms_list = transforms_list

    def __len__(self):
        """Number of test images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(img_id, stacked_tta_views)``."""
        img_id = self.image_ids[idx]
        img_path = os.path.join(self.image_dir, img_id)
        # The context manager closes the file handle promptly. convert()
        # returns a fully loaded copy, so `image` stays valid after close.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")

        # Apply each TTA transform to the same image and stack the results.
        views = torch.stack([tf(image) for tf in self.transforms_list])
        return img_id, views


In [None]:
# Load test image IDs (REQUIRED before preparing test dataset)
# The test image IDs must be loaded before the dataset, which is built from them.
test_ids_df = pd.read_csv(TEST_IDS_CSV)
image_ids = test_ids_df["image_id"].tolist()

# Each dataset item already holds all TTA views of one image. shuffle=False
# keeps the loader in the same order as image_ids.
test_dataset = TTASoilDataset(
    image_ids=image_ids,
    image_dir=TEST_DIR,
    transforms_list=tta_transforms,
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=1)


In [None]:
def predict_tta_probs(model, images, device):
    """Return TTA-averaged softmax probabilities for one loader batch.

    Args:
        model: classifier called on a ``[N, C, H, W]`` batch. Its output is
            softmaxed over dim 1, so it is assumed to return
            ``[N, num_classes]`` logits.
        images: tensor of shape ``[batch, num_tta, C, H, W]``, as yielded by
            a ``DataLoader`` over ``TTASoilDataset``.
        device: device to move the inputs to before the forward pass.

    Returns:
        numpy array of shape ``[batch, num_classes]`` holding, for each image,
        the mean over its TTA views of the softmax probabilities.
    """
    batch_size, num_tta = images.shape[:2]
    # Merge the batch and TTA axes so every view goes through the model in one
    # call. This works for any loader batch size, unlike squeeze(0).
    flat_views = images.flatten(0, 1).to(device)  # [batch * num_tta, C, H, W]
    probs = torch.softmax(model(flat_views), dim=1)
    probs = probs.reshape(batch_size, num_tta, -1)  # [batch, num_tta, num_classes]
    return probs.mean(dim=1).cpu().numpy()


# Re-initialised on every run so re-executing this cell does not duplicate predictions.
model_preds = {img_id: [] for img_id in image_ids}

for fold, model in enumerate(fold_models):
    model.eval()
    model.to(device)
    print(f"Running TTA predictions for Fold {fold+1} ...")

    with torch.no_grad():
        for batch_ids, images in tqdm(test_loader):
            batch_probs = predict_tta_probs(model, images, device)
            # One TTA-averaged probability vector per image in the batch.
            for img_id, mean_prob in zip(batch_ids, batch_probs):
                model_preds[img_id].append(mean_prob)


In [None]:
# Soft-voting ensemble: average each image's per-fold probability vectors,
# then map the index of the most likely class back to its label.
final_preds = []
for img_id in image_ids:
    fold_probs = np.stack(model_preds[img_id], axis=0)  # one row per fold
    avg_probs = fold_probs.mean(axis=0)
    pred_label = label_encoder.classes_[avg_probs.argmax()]
    final_preds.append(pred_label)


In [None]:
# One row per test image, with columns image_id then soil_type, written without the index.
submission = pd.DataFrame({"image_id": image_ids, "soil_type": final_preds})
submission.to_csv("submission.csv", index=False)
print("Submission file created: submission.csv")
