In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Walk the Kaggle input mount and print the full path of every file found.
input_file_paths = (
    os.path.join(root, name)
    for root, _subdirs, names in os.walk('/kaggle/input')
    for name in names
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install terratorch

In [None]:
import albumentations
from albumentations.pytorch import ToTensorV2
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint
from terratorch.datamodules import GenericNonGeoSegmentationDataModule
from terratorch.tasks import SemanticSegmentationTask
import zipfile
import glob
import os
import numpy as np
import rasterio
import pandas as pd

In [None]:
# Root folder of the competition data. The path repeats the dataset name three
# times and ends with a trailing slash.
dataset_path = '/kaggle/input/track1-india-al-impact-gen-ai-hackathon/track1-india-al-impact-gen-ai-hackathon/track1-india-al-impact-gen-ai-hackathon/'

# Every location below is derived from dataset_path with os.path.join so that
# all roots and split files are built the same way (no doubled slashes, no
# second hard-coded copy of the root).
datamodule = GenericNonGeoSegmentationDataModule(
    batch_size=4,
    num_workers=2,
    num_classes=6,

    # Image / label folders for each stage. No test label root is configured.
    train_data_root=os.path.join(dataset_path, 'train/inputs'),
    train_label_data_root=os.path.join(dataset_path, 'train/labels'),
    val_data_root=os.path.join(dataset_path, 'val/inputs'),
    val_label_data_root=os.path.join(dataset_path, 'val/labels'),
    test_data_root=os.path.join(dataset_path, 'test/inputs'),

    # Dataset path used for inference on the test inputs.
    predict_data_root=os.path.join(dataset_path, 'test/inputs'),

    # Split definition files, located directly under the dataset root.
    train_split=os.path.join(dataset_path, 'train.txt'),
    val_split=os.path.join(dataset_path, 'val.txt'),
    test_split=os.path.join(dataset_path, 'test.txt'),

    # Filename patterns that select input rasters and their label rasters.
    img_grep='*input.tif',
    label_grep='*label_c6.tif',

    # Only the training set is augmented; validation and test use the
    # datamodule's default transform.
    train_transform=[
        albumentations.D4(),
        ToTensorV2(),
    ],
    val_transform=None,
    test_transform=None,

    # Normalisation statistics: 12 values each.
    means=[43.377114, 38.762922, 37.587551, 39.397895, 42.61577, 54.785745, 63.259959, 59.998601, 13.367036, 69.212995, 48.322503, 69.708629],
    stds=[3.335747, 4.160813, 5.434037, 9.239101, 8.014329, 6.745426, 8.070073, 7.844921, 2.563382, 16.967517, 15.586694, 9.258978],

    # Replacement values for missing data / missing labels. The -1 label value
    # matches the ignore_index passed to the segmentation task later on.
    no_data_replace=0,
    no_label_replace=-1,
)


# Build the datasets needed for training/validation and for prediction.
datamodule.setup("fit")
datamodule.setup("predict")

In [None]:

# Report how many samples each split holds. The train/val datasets already
# exist after setup("fit"); the test dataset is only built by setup("test").
train_count = len(datamodule.train_dataset)
val_count = len(datamodule.val_dataset)
print("Train set size:", train_count)
print("Validation set size:", val_count)

datamodule.setup("test")
test_count = len(datamodule.test_dataset)
print("Test set size:", test_count)

In [None]:
# Visual sanity check: plot one validation sample (index 10).
val_dataset = datamodule.val_dataset
val_sample = val_dataset[10]
val_dataset.plot(val_sample)

In [None]:
import os

# Debugging aid: makes CUDA kernel launches synchronous so errors surface at
# the failing call. NOTE(review): this slows training; consider removing it
# once the pipeline is stable.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# --- Output location -------------------------------------------------------
OUT_DIR = "/kaggle/working"

# --- Training schedule -----------------------------------------------------
# NOTE(review): BATCH_SIZE is not referenced by any cell visible in this
# notebook; the datamodule is constructed with its own batch_size.
BATCH_SIZE = 16
EPOCHS = 1

# --- Optimisation / regularisation -----------------------------------------
LR = 1e-5
WEIGHT_DECAY = 0.1
HEAD_DROPOUT = 0.1
FREEZE_BACKBONE = False

# --- Model input layout ----------------------------------------------------
BANDS = list(range(1, 13))  # band indices 1..12
NUM_FRAMES = 1

# --- Class weights ---------------------------------------------------------
#      Crop     Pixel_total    Pixel_percentage Class_weight
#      Gram        2545             1.73       0.4761
#     Maize        7128             4.84       0.1702
#   Mustard       36362            24.67       0.0334
# Sugarcane        4542             3.08       0.2674
#     Wheat       59585            40.42       0.0204
# OtherCrop       37247            25.27       0.0326
CROP_CLASS_WEIGHTS = {
    "Gram": 0.4761,
    "Maize": 0.1702,
    "Mustard": 0.0334,
    "Sugarcane": 0.2674,
    "Wheat": 0.0204,
    "OtherCrop": 0.0326,
}
# Weights in table order, as a plain list.
CLASS_WEIGHTS = list(CROP_CLASS_WEIGHTS.values())

# --- Reproducibility -------------------------------------------------------
SEED = 0
pl.seed_everything(SEED)

In [None]:
# Re-seed right before the logger, callbacks and model are created. SEED is
# defined once in the configuration cell; it is not redefined here so there is
# a single source of truth.
pl.seed_everything(SEED)

# Logger: TensorBoard event files go to <OUT_DIR>/cropid/.
logger = TensorBoardLogger(
    save_dir=OUT_DIR,
    name="cropid",
)

# Callbacks: keep only the single best checkpoint, ranked by the validation
# Jaccard index (higher is better).
checkpoint_callback = ModelCheckpoint(
    monitor="val/Multiclass_Jaccard_Index",
    mode="max",
    dirpath=os.path.join(OUT_DIR, "cropid", "checkpoints"),
    # Lightning fills unknown metric names in the template with 0, so the
    # template must use the exact logged key. Logged keys here are prefixed
    # with "val/" (see `monitor`), hence "val/loss" rather than "val_loss".
    # NOTE(review): confirm that the task logs the loss as "val/loss".
    # auto_insert_metric_name=False keeps the "/" in the key from being
    # inserted into the file name; the "epoch=" / "val_loss=" labels are
    # written out explicitly instead.
    filename="best-checkpoint-epoch={epoch:02d}-val_loss={val/loss:.2f}",
    auto_insert_metric_name=False,
    save_top_k=1,
)

In [None]:
# Trainer
# NOTE: a later cell rebinds `trainer` to a fresh Trainer before calling fit().
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices="auto",
    precision="16-mixed",
    num_nodes=1,
    logger=logger,
    max_epochs=EPOCHS,
    check_val_every_n_epoch=1,
    log_every_n_steps=10,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    # limit_predict_batches is intentionally left at its default: predictions
    # are later written out for the whole test set, so restricting predict()
    # to a single batch would leave most tiles without a prediction.
)

# DataModule (alias of the datamodule built earlier)
data_module = datamodule


# Model

# Backbone: pretrained Prithvi encoder fed with the bands listed in BANDS.
backbone_args = dict(
    backbone_pretrained=True,
    backbone="prithvi_eo_v2_300_tl", # prithvi_eo_v2_300, prithvi_eo_v2_300_tl, prithvi_eo_v2_600, prithvi_eo_v2_600_tl
    #backbone_coords_encoding=["time", "location"],
    backbone_bands=BANDS,
    # Use the shared NUM_FRAMES constant (also used by the neck below) so the
    # backbone and the neck cannot drift apart.
    backbone_num_frames=NUM_FRAMES,
)

# Decoder configuration.
decoder_args = dict(
    decoder="UperNetDecoder",
    decoder_channels=256,
    decoder_scale_modules=True
)

# Neck applied between backbone and decoder.
necks = [
    dict(
            name="ReshapeTokensToImage",
            effective_time_dim=NUM_FRAMES,
        )
    ]

# Full argument set handed to the model factory.
model_args = dict(
    **backbone_args,
    **decoder_args,
    num_classes=6,
    head_dropout=HEAD_DROPOUT,
    necks=necks,
    rescale=True
)
    

# Keyword arguments for the segmentation task, grouped by concern.
task_kwargs = dict(
    # Architecture
    model_factory="EncoderDecoderFactory",
    model_args=model_args,
    freeze_backbone=FREEZE_BACKBONE,
    freeze_decoder=False,
    # Loss. Pixels labelled -1 are excluded via ignore_index.
    # NOTE(review): check whether terratorch applies class_weights for the
    # "dice" loss or only for cross-entropy; they may be silently unused here.
    loss="dice",
    class_weights=CLASS_WEIGHTS,
    ignore_index=-1,
    # Optimiser
    optimizer="AdamW",
    lr=LR,
    optimizer_hparams=dict(weight_decay=WEIGHT_DECAY),
    # Validation plotting disabled
    plot_on_val=False,
)

model = SemanticSegmentationTask(**task_kwargs)


In [None]:

# Trainer used for both fitting and the later predict() call.
trainer = pl.Trainer(
    accelerator="auto",
    strategy="auto",
    devices="auto",
    precision="16-mixed",
    num_nodes=1,
    logger=logger,
    max_epochs=EPOCHS,
    check_val_every_n_epoch=1,
    log_every_n_steps=10,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    # No limit_predict_batches here: this trainer also produces the test-set
    # predictions that are saved to disk and turned into the submission file,
    # so predict() must run over every batch, not just the first one.
)

# Fine-tune the model on the train split, validating once per epoch.
trainer.fit(model, datamodule=datamodule)


In [None]:

# Locate the checkpoint to use for inference.
checkpoint_dir = "/kaggle/working/cropid/checkpoints"
checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "best-checkpoint-*.ckpt"))

# Prefer the path that the checkpoint callback recorded as best during fit();
# it is unambiguous even when older checkpoints are still on disk.
best_ckpt_path = checkpoint_callback.best_model_path

if not (best_ckpt_path and os.path.isfile(best_ckpt_path)):
    # Fallback: search the checkpoint directory.
    if len(checkpoint_files) == 0:
        raise FileNotFoundError("No best checkpoint file found in the directory.")

    # glob returns matches in arbitrary order, so pick the most recently
    # written file instead of whichever happens to come first.
    best_ckpt_path = max(checkpoint_files, key=os.path.getmtime)

print(best_ckpt_path)


In [None]:
# A. Saving predictions on test input images

# Run inference on predict_data_root (configured on the datamodule) using the
# weights stored in best_ckpt_path.
preds = trainer.predict(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

# Predicted class maps are written as GeoTIFFs under <output_dataset_path>/test_pred.
output_dataset_path = "/kaggle/working/"
output_dir = os.path.join(output_dataset_path, "test_pred")
os.makedirs(output_dir, exist_ok=True)

for batch in preds:
    # Each entry pairs the model outputs with the source file paths.
    # Presumably batch[0][0] has shape (batch, H, W) - TODO confirm.
    tensor = batch[0][0]
    file_paths = batch[1]

    for i, prediction in enumerate(tensor):
        # Single-band integer array for this image.
        arr = prediction.cpu().numpy().astype('int32')

        # Copy the georeferencing (CRS + affine transform) from the input tile.
        ref_path = file_paths[i]
        with rasterio.open(ref_path) as src_ref:
            ref_crs = src_ref.crs
            ref_transform = src_ref.transform

        # "<input name>_pred.tif" inside the output directory.
        ref_stem = os.path.splitext(os.path.basename(ref_path))[0]
        out_name = ref_stem + "_pred.tif"
        out_path = os.path.join(output_dir, out_name)

        n_rows, n_cols = arr.shape[0], arr.shape[1]
        geotiff_profile = dict(
            driver="GTiff",
            height=n_rows,
            width=n_cols,
            count=1,
            dtype=arr.dtype,
            crs=ref_crs,
            transform=ref_transform,
        )
        with rasterio.open(out_path, "w", **geotiff_profile) as dst:
            dst.write(arr, 1)

        print(f"Saved {out_path}")

In [None]:
## Model prediction on the test input images, turned into the submission file
#
# A. Saving predictions on test input images (done in the previous cell)
# B. Generating the submission file (`prediction.csv`) using the mask and prediction TIFFs


def _encode_tile(pred_data, mask_data, base_name, num_classes=6):
    """Build the submission rows for one predicted tile.

    Args:
        pred_data: 2-D array of predicted class indices.
        mask_data: 2-D array; only pixels where it equals 1 are reported.
        base_name: File name of the prediction TIFF; its "_input_pred.tif"
            suffix is replaced by "_<class>" to form each row id.
        num_classes: Number of classes; one row is produced per class.

    Returns:
        A list of ``[id, label]`` rows, one per class. ``label`` is a
        space-separated sequence of "pixel_id class" pairs, where pixel_id is
        the flattened (row-major) index of the pixel in the tile. A class with
        no pixels gets the placeholder label "0".
    """
    valid_idx = np.where(mask_data == 1)
    pixel_ids = np.ravel_multi_index(valid_idx, pred_data.shape)
    masked_preds = pred_data[valid_idx]

    rows = []
    for cls in range(num_classes):
        cls_pixels = pixel_ids[masked_preds == cls]
        pred_str = " ".join(f"{pid} {cls}" for pid in cls_pixels.tolist())
        file_id = base_name.replace("_input_pred.tif", f"_{cls}")
        # An empty string is not NaN, so a later fillna() would never replace
        # it; substitute the "0" placeholder right here instead.
        rows.append([file_id, pred_str or "0"])
    return rows


# Input directories and output file
dir_pred = os.path.join(output_dataset_path, "test_pred")
dir_mask = os.path.join(dataset_path, "test/labels")
prediction_csv = os.path.join(output_dataset_path, "prediction.csv")

records = []

# Sorted so the CSV row order is deterministic (glob order is arbitrary).
pred_files = sorted(glob.glob(os.path.join(dir_pred, "*.tif")))

for pred_file in pred_files:
    base_name = os.path.basename(pred_file)
    mask_file = os.path.join(dir_mask, base_name.replace("input_pred.tif", "label_mask.tif"))

    if not os.path.exists(mask_file):
        print(f"No mask found for {base_name}, skipping")
        continue

    # Read prediction (band 1)
    with rasterio.open(pred_file) as src_pred:
        pred_data = src_pred.read(1)

    # Read mask (band 1)
    with rasterio.open(mask_file) as src_mask:
        mask_data = src_mask.read(1)

    records.extend(_encode_tile(pred_data, mask_data, base_name))

# One row per (tile, class)
df = pd.DataFrame(records, columns=["id", "label"])

# Save CSV
df.to_csv(prediction_csv, index=False)

print(f"Saved predictions to {prediction_csv}")