# In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import geopandas as gpd
import rasterio
import numpy as np
from rasterio.mask import mask as rio_mask
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Configuration
# One raster per acquisition date. The order of this list is the order of the
# time steps in every sequence produced by the dataset.
IMAGE_PATHS = [
    "T34TCR_20230613T094041_TCI_10m.jp2",
    "T34TCR_20230623T094031_TCI_10m.jp2",
    "T34TCR_20230703T094041_TCI_10m.jp2",
    "T34TCR_20230708T093549_TCI_10m.jp2",
    "T34TCR_20230713T094041_TCI_10m.jp2",
    "T34TCR_20230906T093549_TCI_10m.jp2",
]

# Polygon layer whose "major_clas" attribute supplies the class labels.
SHAPEFILE_WITH_CLASSES = "celo4_with_class.shp"

# Number of features per time step fed to the model (one mean per band).
NUM_BANDS = 3

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Class count includes class 20: labels 0-6 keep their value as the class
# index, and label 20 is remapped to index 7.
NUM_CLASSES = 8

# Dataset Class
def load_dataset():
    """Build the crop time-series dataset from the module-level configuration.

    Returns:
        A ``CropTimeSeriesDataset`` that reads polygons and labels from
        ``SHAPEFILE_WITH_CLASSES`` and pixel values from ``IMAGE_PATHS``.
    """
    class CropTimeSeriesDataset(Dataset):
        """One sample per shapefile feature: a sequence of per-band pixel means.

        Each item is ``(sequence, label)``. ``sequence`` is a float32 tensor
        with one row per entry of ``image_paths`` and one column per raster
        band; ``label`` is a scalar ``long`` tensor.
        """

        def __init__(self, shp_file, image_paths):
            """Read the polygon layer and remember which rasters to sample.

            Args:
                shp_file: Path to a vector file with a ``geometry`` column
                    and a ``major_clas`` label column.
                image_paths: Raster paths, one per time step, in sequence
                    order.
            """
            super().__init__()
            self.gdf = gpd.read_file(shp_file)

            # Keep all classes including 20
            self.geometries = self.gdf["geometry"].tolist()
            self.labels = self.gdf["major_clas"].tolist()
            self.image_paths = image_paths
            # The band means of a polygon are the same in every epoch, so they
            # are computed on first access and reused afterwards instead of
            # re-opening and re-decoding every raster for every sample.
            self._sequence_cache = {}

        def __len__(self):
            return len(self.geometries)

        def _band_means(self, geom):
            """Return the per-image, per-band mean pixel value inside ``geom``.

            The result is a float32 array with one row per raster and one
            column per band. A band with no usable pixel contributes 0.0.
            """
            rows = []
            for img_path in self.image_paths:
                with rasterio.open(img_path) as src:
                    # filled=False yields a masked array instead of writing a
                    # sentinel value into the data. Selecting pixels by a
                    # sentinel such as 255 would also discard genuine pixels
                    # that happen to hold that value (e.g. saturated 8-bit
                    # pixels); the mask has no such ambiguity.
                    clipped, _ = rio_mask(src, [geom], crop=True, filled=False)
                row = []
                for band in clipped:
                    # compressed() keeps only the unmasked pixels as a 1-D array.
                    inside = np.ma.compressed(band)
                    row.append(float(inside.mean()) if inside.size else 0.0)
                rows.append(row)
            return np.array(rows, dtype=np.float32)

        def __getitem__(self, idx):
            """Return ``(sequence, label)`` for the feature at position ``idx``."""
            idx = int(idx)
            sequence = self._sequence_cache.get(idx)
            if sequence is None:
                sequence = self._band_means(self.geometries[idx])
                self._sequence_cache[idx] = sequence
            # Hand out a copy so in-place edits by a caller cannot corrupt
            # the cached array (torch.from_numpy shares memory with its input).
            reflectances_tensor = torch.from_numpy(sequence.copy())

            label = self.labels[idx]
            # Map class 20 to index 7
            label = 7 if label == 20 else label  # Adjust labels for loss function
            label_tensor = torch.tensor(label, dtype=torch.long)
            return reflectances_tensor, label_tensor

    return CropTimeSeriesDataset(SHAPEFILE_WITH_CLASSES, IMAGE_PATHS)


# LSTM Model
def create_model():
    """Instantiate the sequence classifier sized by NUM_BANDS and NUM_CLASSES.

    Returns:
        A freshly initialised ``CropLSTM`` module (not yet moved to a device).
    """
    class CropLSTM(nn.Module):
        """Two-layer LSTM whose final time-step output feeds a linear layer."""

        def __init__(self, input_size=NUM_BANDS, hidden_size=128, num_classes=NUM_CLASSES):
            super().__init__()
            self.lstm = nn.LSTM(
                input_size,
                hidden_size,
                num_layers=2,
                dropout=0.3,
                batch_first=True,
            )
            self.fc = nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            """Return class logits computed from the last time step of ``x``."""
            # The final hidden/cell states are not needed; only the per-step
            # outputs are used. With batch_first=True, dim 1 is the time axis.
            sequence_out, _ = self.lstm(x)
            return self.fc(sequence_out[:, -1, :])

    return CropLSTM(input_size=NUM_BANDS, hidden_size=128, num_classes=NUM_CLASSES)


# Training Logic
def train_model():
    """Train the LSTM classifier on a random 80/20 split and print test metrics.

    Loads the dataset, trains for five epochs with Adam and a step learning
    rate schedule, printing the mean training loss per epoch, then prints
    accuracy, a confusion matrix and a classification report for the held-out
    split.

    Raises:
        ValueError: If the dataset has fewer than two samples, so that the
            80% training split would be empty.
    """
    dataset = load_dataset()

    # Split dataset into train and test sets
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    if train_size == 0:
        # Fail here with a clear message rather than later inside the
        # shuffling sampler, which rejects an empty dataset.
        raise ValueError(
            f"Need at least 2 samples to build a training split, got {len(dataset)}"
        )
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    model = create_model().to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

    epochs = 5
    for epoch in range(epochs):
        model.train()
        loss_sum = 0.0
        for seq, lbl in train_loader:
            seq, lbl = seq.to(DEVICE), lbl.to(DEVICE)
            optimizer.zero_grad()
            logits = model(seq)
            loss = criterion(logits, lbl)
            loss.backward()
            optimizer.step()

            # The criterion returns the mean over the batch. Weighting by the
            # batch size keeps a short final batch from counting as much as a
            # full one in the epoch average.
            loss_sum += loss.item() * lbl.size(0)

        avg_loss = loss_sum / train_size
        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    # Evaluate on test set
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for seq, lbl in test_loader:
            seq, lbl = seq.to(DEVICE), lbl.to(DEVICE)
            logits = model(seq)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(lbl.cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {acc:.4f}")
    print("Confusion Matrix:")
    # Pin the label set so row/column i always means class index i, even when
    # a class is absent from this particular test split.
    print(confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES))))
    print("Classification Report:")
    print(classification_report(all_labels, all_preds))


# Execute Training
# Guarded so that importing this file (e.g. to reuse the dataset or model
# factory) does not start a training run. The condition is still true when the
# code is executed as a script or inside a notebook kernel.
if __name__ == "__main__":
    train_model()
