In [1]:


   # Imports
import json
import math
import heapq
from collections import deque
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
            
 
  

In [2]:
import os
import shutil
import csv
import json
import heapq
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
from torchvision import transforms, models, datasets

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# When /kaggle/input is present, read from the competition dataset location and
# write to /kaggle/working; otherwise use paths relative to the current directory.
BASE_INPUT, WORKING_DIR = (
    (
        Path("/kaggle/input/the-blind-flight-synapse-drive-ps-1/SynapseDrive_Dataset"),
        Path("/kaggle/working"),
    )
    if Path("/kaggle/input").exists()
    else (Path("SynapseDrive_Dataset"), Path("."))
)

# Inputs: training maps with their JSON labels, test maps with their velocity files.
TRAIN_IMG_DIR = BASE_INPUT / "train" / "images"
TRAIN_LABEL_DIR = BASE_INPUT / "train" / "labels"
TEST_IMG_DIR = BASE_INPUT / "test" / "images"
TEST_VEL_DIR = BASE_INPUT / "test" / "velocities"

# Outputs: sliced tiles, one checkpoint per model, and the final CSV.
TILES_DIR = WORKING_DIR / "train_dataset_tiles"
MODEL_RESNET = WORKING_DIR / "tile_classifier_resnet.pth"
MODEL_EFFNET = WORKING_DIR / "tile_classifier_effnet.pth"
SUBMISSION_FILE = WORKING_DIR / "submission.csv"

# ---------------------------------------------------------------------------
# Training hyperparameters
# ---------------------------------------------------------------------------
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
BATCH_SIZE = 32
EPOCHS = 8
LR = 5e-4

#1. DATA PREPARATION 
def step_1_prepare_data():
    """Slice every labelled training map into per-cell tiles, filed by class.

    TILES_DIR is wiped and recreated with sub-folders "0".."4". For each PNG in
    TRAIN_IMG_DIR that has a same-stem JSON file in TRAIN_LABEL_DIR, the image is
    divided into as many rows/columns as the JSON "grid" has, and each cell is
    saved as TILES_DIR/<label>/<stem>_<row>_<col>.png. Maps without a label file
    are reported and skipped.
    """
    print("\n[1/4] Preparing Data (Slicing Maps from JSON)...")

    # Start from a clean output tree so tiles from earlier runs cannot linger.
    if TILES_DIR.exists():
        shutil.rmtree(TILES_DIR)
    for class_dir in ("0", "1", "2", "3", "4"):
        (TILES_DIR / class_dir).mkdir(parents=True, exist_ok=True)

    saved = 0
    for map_path in tqdm(sorted(TRAIN_IMG_DIR.glob("*.png"))):
        label_path = TRAIN_LABEL_DIR / f"{map_path.stem}.json"
        if not label_path.exists():
            print(f"Warning: No label found for {map_path.name}, skipping.")
            continue

        map_img = Image.open(map_path).convert("RGB")
        width, height = map_img.size

        with open(label_path, 'r') as fh:
            label_grid = np.array(json.load(fh)["grid"])

        # Cell size comes from the label grid's own shape; integer division
        # discards any remainder pixels on the right/bottom edge.
        n_rows, n_cols = label_grid.shape
        cell_w = width // n_cols
        cell_h = height // n_rows

        # np.ndindex walks (row, col) pairs in row-major order.
        for row, col in np.ndindex(n_rows, n_cols):
            x0 = col * cell_w
            y0 = row * cell_h
            tile = map_img.crop((x0, y0, x0 + cell_w, y0 + cell_h))

            # The cell's label picks the destination class folder.
            out_path = TILES_DIR / str(label_grid[row, col]) / f"{map_path.stem}_{row}_{col}.png"
            tile.save(out_path)
            saved += 1

    print(f"Data Prepared: {saved} tiles generated.")

#2. TRAINING  

def train_model(model_name, save_path):
    """Fine-tune an ImageNet-pretrained CNN as a 5-class tile classifier.

    The tiles under TILES_DIR are split 80/20 at random into training and
    validation subsets. Training batches are drawn with a
    WeightedRandomSampler whose weights are the inverse class frequencies of
    the training subset. After every epoch the validation accuracy is
    printed, fed to a ReduceLROnPlateau scheduler, and the weights are saved
    whenever accuracy is at least as good as the best so far.

    Args:
        model_name: "resnet" selects ResNet-18; any other value selects
            EfficientNet-B0.
        save_path: Destination of the saved ``state_dict``.
    """
    print(f"\n[Training] {model_name}...")

    # Transformations
    train_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    val_tf = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Datasets
    # Both subsets of a random_split share ONE underlying dataset object, so
    # assigning val_tf to `val_ds.dataset.transform` would also replace the
    # training transform and silently disable augmentation. Instead, open the
    # folder twice (one view per transform) and split a single shuffled index
    # list; both views enumerate the same folder, so an index refers to the
    # same tile in each.
    train_view = datasets.ImageFolder(TILES_DIR, transform=train_tf)
    val_view = datasets.ImageFolder(TILES_DIR, transform=val_tf)
    n_total = len(train_view)
    train_len = int(0.8 * n_total)
    shuffled = torch.randperm(n_total).tolist()
    train_ds = torch.utils.data.Subset(train_view, shuffled[:train_len])
    val_ds = torch.utils.data.Subset(val_view, shuffled[train_len:])

    # Weights: inverse class frequency, so rare classes are sampled as often
    # as common ones.
    targets = [train_view.targets[i] for i in train_ds.indices]
    counts = np.bincount(targets)
    weights = 1. / np.maximum(counts, 1)
    sample_weights = [weights[t] for t in targets]
    sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, sampler=sampler, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    # Model Selection: swap the final linear layer for a 5-way head.
    if model_name == "resnet":
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 5)
    else: # efficientnet
        model = models.efficientnet_b0(pretrained=True)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, 5)

    model = model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    criterion = nn.CrossEntropyLoss()
    # mode='max' because the monitored quantity is accuracy, not loss.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)

    best_acc = 0.0

    for epoch in range(EPOCHS):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Guard against an empty validation split (very small tile sets).
        acc = correct / total if total else 0.0
        scheduler.step(acc)
        print(f"  Ep {epoch+1}: Acc {acc:.4f}")

        # ">=" keeps the latest checkpoint among ties and guarantees a file is
        # written even if accuracy never rises above 0.
        if acc >= best_acc:
            best_acc = acc
            torch.save(model.state_dict(), save_path)

    print(f"Saved {model_name} with Acc: {best_acc:.4f}")

# 3. SUBMISSION (Ensemble + TTA + GPS Fix)

def step_4_generate_submission():
    """Predict a tile grid for every test map, route through it, and write the CSV.

    For each PNG in TEST_IMG_DIR:
      1. The image and its horizontal mirror are each cut into 20x20 tiles
         (test-time augmentation) and classified by both saved models; the
         four softmax outputs per cell are averaged and arg-maxed into a grid.
      2. If no cell was predicted as class 3 (start) or class 4 (goal), the
         most probable cell for that class is forced to it.
      3. A biome ("lab", "forest" or "desert") is guessed from the colours of
         a 20x20 thumbnail and selects the per-class step costs.
      4. Dijkstra's algorithm finds the cheapest start-to-goal route, where
         entering a cell costs ``max(0.01, base_cost - boost)`` and ``boost``
         comes from the matching file in TEST_VEL_DIR (zeros if absent).

    The routes are written to SUBMISSION_FILE as ``image_id,path`` rows, with
    each path a string over the letters u/d/l/r.
    """
    print("\n[4/4] Generating Ensemble Submission...")

    # 1. Loading Brain 1 (ResNet)
    m1 = models.resnet18(pretrained=False)
    m1.fc = nn.Linear(m1.fc.in_features, 5)
    m1.load_state_dict(torch.load(MODEL_RESNET, map_location=DEVICE))
    m1.to(DEVICE).eval()

    # 2. Loading Brain 2 (EffNet)
    m2 = models.efficientnet_b0(pretrained=False)
    m2.classifier[1] = nn.Linear(m2.classifier[1].in_features, 5)
    m2.load_state_dict(torch.load(MODEL_EFFNET, map_location=DEVICE))
    m2.to(DEVICE).eval()

    # 3. Setting up Dataset
    # NOTE(review): a class defined inside a function cannot be pickled, so
    # num_workers > 0 presumably relies on a fork-based worker start method
    # here — confirm before running on a spawn-based platform.
    class TestDS(Dataset):
        """Yields (image stem, stacked tiles) for every PNG in a directory.

        The tile stack holds the 400 tiles of the original image followed by
        the 400 tiles of its left-right mirror, each in row-major order.
        """

        def __init__(self, d):
            self.files = sorted(list(d.glob("*.png")))
            self.tf = transforms.Compose([
                transforms.Resize((224, 224)), transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            p = self.files[idx]
            img = Image.open(p).convert("RGB")
            # TTA: Normal + Flip
            tiles = []
            for im in [img, img.transpose(Image.FLIP_LEFT_RIGHT)]:
                w, h = im.size
                tw, th = w // 20, h // 20
                for r in range(20):
                    for c in range(20):
                        tile = im.crop((c*tw, r*th, (c+1)*tw, (r+1)*th))
                        tiles.append(self.tf(tile))
            return p.stem, torch.stack(tiles)

    ds = TestDS(TEST_IMG_DIR)
    dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=2)

    # 4. Biome & Dijkstra Logic
    # Cost of entering a cell of each predicted class, per biome. Class 1 is
    # priced at 9999 so it is only crossed when nothing else is possible.
    BIOME_COSTS = {
        "lab": {0:1.0, 1:9999.0, 2:3.0, 3:1.0, 4:2.0},
        "forest": {0:1.5, 1:9999.0, 2:2.8, 3:1.5, 4:2.5},
        "desert": {0:1.2, 1:9999.0, 2:3.7, 3:1.2, 4:2.2}
    }

    results = []

    with torch.no_grad():
        for fid, batch in tqdm(dl):
            curr_id = fid[0]

            # Biome (20x20 Logic): count clearly-green pixels and measure the
            # overall yellow tint of a 20x20 thumbnail.
            raw = Image.open(TEST_IMG_DIR / f"{curr_id}.png").convert("RGB")
            sm = raw.resize((20,20))
            pix = list(sm.getdata())
            g_votes = sum(1 for r,g,b in pix if g > r+15 and g > b+15)
            y_score = sum((r+g)-(2*b) for r,g,b in pix)
            biome = "lab"
            if g_votes > 10: biome = "forest"
            elif y_score/400 > 40: biome = "desert"
            costs = BIOME_COSTS[biome]

            # Ensemble Prediction
            inp = batch.squeeze(0).to(DEVICE) # [800, 3, 224, 224]
            out1 = torch.softmax(m1(inp), dim=1).cpu().numpy()
            out2 = torch.softmax(m2(inp), dim=1).cpu().numpy()
            avg = (out1 + out2) / 2.0

            # TTA Merge (avg of both): the mirrored image's column c is the
            # original's column 19-c, so flip it back before averaging.
            p_norm = avg[:400].reshape(20, 20, 5)
            p_flip = avg[400:].reshape(20, 20, 5)
            p_final = (p_norm + np.fliplr(p_flip)) / 2.0

            grid = np.argmax(p_final, axis=2)

            # GPS Fix: guarantee one start (3) and one goal (4) on the grid.
            if 3 not in grid:
                grid[divmod(np.argmax(p_final[:,:,3]), 20)] = 3
            if 4 not in grid:
                # Exclude start cells from the search; otherwise a collision
                # with the start would leave the grid without any goal and the
                # route would silently target the (19, 19) fallback.
                goal_scores = p_final[:,:,4].copy()
                goal_scores[grid == 3] = -1.0  # softmax outputs are >= 0
                grid[divmod(np.argmax(goal_scores), 20)] = 4

            # Dijkstra
            vel_p = TEST_VEL_DIR / f"{curr_id}.json"
            if vel_p.exists():
                # Context manager closes the handle; this runs once per image.
                with open(vel_p) as vel_file:
                    v_grid = np.array(json.load(vel_file)["boost"])
            else:
                v_grid = np.zeros((20,20))

            starts = np.argwhere(grid==3)
            goals = np.argwhere(grid==4)
            start = tuple(starts[0]) if len(starts)>0 else (0,0)
            goal = tuple(goals[0]) if len(goals)>0 else (19,19)

            # Heap entries are (cost so far, row, col, moves taken so far).
            pq = [(0, start[0], start[1], "")]
            vis = {}
            path_str = "r" # default

            while pq:
                cost, row, col, moves = heapq.heappop(pq)
                if (row, col) == goal:
                    path_str = moves
                    break
                # Skip stale entries for cells already settled at lower cost.
                if (row, col) in vis and vis[(row, col)] <= cost: continue
                vis[(row, col)] = cost

                for dr, dc, char in [(-1,0,'u'), (1,0,'d'), (0,-1,'l'), (0,1,'r')]:
                    nr, nc = row+dr, col+dc
                    if 0<=nr<20 and 0<=nc<20:
                        ct = grid[nr, nc]
                        base = costs.get(ct, 1.0)
                        boost = v_grid[nr][nc]
                        # Floor at 0.01 so every step has a positive cost,
                        # as Dijkstra requires.
                        cost_step = max(0.01, base - boost)
                        heapq.heappush(pq, (cost + cost_step, nr, nc, moves + char))

            results.append([curr_id, path_str])

    with open(SUBMISSION_FILE, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(["image_id", "path"])
        writer.writerows(results)
    print(f"DONE! Saved to {SUBMISSION_FILE}")

# MAIN EXECUTION

if __name__ == "__main__":
    # Stage 1: slice the labelled maps into class-sorted tiles.
    step_1_prepare_data()

    # Stages 2-3: train each ensemble member in turn (banner, model name, checkpoint).
    training_plan = (
        ("\n[2/4] Training Brain 1 (ResNet)...", "resnet", MODEL_RESNET),
        ("\n[3/4] Training Brain 2 (EfficientNet)...", "efficientnet", MODEL_EFFNET),
    )
    for banner, arch, checkpoint in training_plan:
        print(banner)
        train_model(arch, checkpoint)

    # Stage 4: ensemble inference, routing, and CSV export.
    step_4_generate_submission()


[1/4] Preparing Data (Slicing Maps from JSON)...


100%|██████████| 20/20 [00:10<00:00,  1.94it/s]
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


Data Prepared: 8000 tiles generated.

[2/4] Training Brain 1 (ResNet)...

[Training] resnet...


100%|██████████| 44.7M/44.7M [00:00<00:00, 177MB/s]


  Ep 1: Acc 1.0000
  Ep 2: Acc 1.0000
  Ep 3: Acc 1.0000
  Ep 4: Acc 1.0000
  Ep 5: Acc 1.0000
  Ep 6: Acc 1.0000
  Ep 7: Acc 1.0000
  Ep 8: Acc 1.0000
Saved resnet with Acc: 1.0000

[3/4] Training Brain 2 (EfficientNet)...

[Training] efficientnet...


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 184MB/s]


  Ep 1: Acc 1.0000
  Ep 2: Acc 1.0000
  Ep 3: Acc 1.0000
  Ep 4: Acc 1.0000
  Ep 5: Acc 1.0000
  Ep 6: Acc 1.0000
  Ep 7: Acc 1.0000
  Ep 8: Acc 1.0000
Saved efficientnet with Acc: 1.0000

[4/4] Generating Ensemble Submission...


100%|██████████| 10000/10000 [5:36:08<00:00,  2.02s/it] 

DONE! Saved to /kaggle/working/submission.csv



