In [14]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torchvision.models import resnet18, ResNet18_Weights

In [15]:


# ---------------------------
# Config
# ---------------------------
class Config:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img_size = 320
    batch_size = 16
    lr = 1e-4
    epochs = 10
    num_workers = 4
    num_classes = 3  # no fog, medium fog, dense fog
    SAVE_DIR = './checkpoints'
    classes = ['No_Fog', 'Medium_Fog', 'Dense_Fog']  # Match your folder names exactly

os.makedirs(Config.SAVE_DIR, exist_ok=True)

# ---------------------------
# Grad-CAM for Classification
# ---------------------------
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handle = None

        def save_activation(module, input, output):
            self.activations = output.detach()

        def save_gradient(grad):
            self.gradients = grad

        self.hook_handle = self.target_layer.register_forward_hook(save_activation)
        self.target_layer.register_backward_hook(lambda module, grad_in, grad_out: save_gradient(grad_out[0]))

    # def __call__(self, x, class_idx=None):
    #     self.model.zero_grad()
    #     output = self.model(x)
    #     if class_idx is None:
    #         class_idx = output.argmax(dim=1).item()
    #     one_hot = torch.zeros_like(output)
    #     one_hot[0][class_idx] = 1
    #     output.backward(gradient=one_hot, retain_graph=False)

    #     weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)
    #     cam = torch.sum(weights * self.activations, dim=1, keepdim=True)
    #     cam = F.relu(cam)
    #     cam = F.interpolate(cam, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
    #     cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    #     return cam[0, 0].cpu().numpy()
    def __call__(self, x, class_idx=None):
        self.model.zero_grad()
        model_output = self.model(x)
        
        # Handle tuple output (logits, features)
        if isinstance(model_output, tuple):
            logits = model_output[0]
        else:
            logits = model_output

        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        
        one_hot = torch.zeros_like(logits)
        one_hot[0][class_idx] = 1
        logits.backward(gradient=one_hot, retain_graph=False)

        # Compute weights via global average pooling of gradients
        weights = torch.mean(self.gradients, dim=(2, 3), keepdim=True)  # [1, C, 1, 1]
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)  # [1, 1, H, W]
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=(x.shape[2], x.shape[3]), mode='bilinear', align_corners=False)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam[0, 0].cpu().numpy()
    def remove_hook(self):
        if self.hook_handle:
            self.hook_handle.remove()

# ---------------------------
# Dataset for Fog Classification (Accepts list of samples)
# ---------------------------
class FogClassificationDataset(Dataset):
    def __init__(self, samples, transform=None):
        """
        samples: list of (image_path, label)
        """
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(img)
        return img, label, img_path

# ---------------------------
# Improved LTC Cell (Kept as-is)
# ---------------------------
class ImprovedLTCCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, time_steps=10):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.time_steps = time_steps
        self.W_ih = nn.Linear(input_dim, hidden_dim)
        self.W_hh = nn.Linear(hidden_dim, hidden_dim)
        self.tau = nn.Parameter(torch.ones(hidden_dim) * 0.5)
        self.alpha = nn.Parameter(torch.ones(hidden_dim))
        self.beta = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x_seq, hidden=None):
        T, B, _ = x_seq.shape
        if hidden is None:
            hidden = torch.zeros(B, self.hidden_dim, device=x_seq.device)
        tau = F.softplus(self.tau) + 0.01
        dt = 1.0 / self.time_steps
        outputs = []
        for t in range(T):
            inp = x_seq[t]
            for _ in range(self.time_steps):
                activation = torch.tanh(self.alpha * (self.W_ih(inp) + self.W_hh(hidden)) + self.beta)
                dhdt = (-hidden + activation) / tau.unsqueeze(0)
                hidden = hidden + dt * dhdt
            outputs.append(hidden)
        return torch.stack(outputs, dim=0), hidden

class LiquidTemporalModule(nn.Module):
    def __init__(self, feat_dim, hidden_dim=256, num_layers=2):
        super().__init__()
        self.proj = nn.Linear(feat_dim, hidden_dim)
        self.ltc_layers = nn.ModuleList([
            ImprovedLTCCell(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, feat_seq):
        if not feat_seq:
            return torch.zeros(1, 256, device=feat_seq[0].device)
        seq = torch.stack([self.proj(f) for f in feat_seq], dim=0)
        hidden = None
        for ltc in self.ltc_layers:
            seq, hidden = ltc(seq, hidden)
        return seq[-1]

# ---------------------------
# Final Model: Fog Classifier with LTC & Dehaze
# ---------------------------
class FogClassifier(nn.Module):
    def __init__(self, num_classes=3, use_ltc=True):
        super().__init__()
        self.use_ltc = use_ltc
        self.dehaze = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 3, 1),
            nn.Sigmoid()
        )
        # Load ImageNet weights only if available locally ‚Äì no download attempt
        self.backbone = resnet18(weights=None)  # Or weights=ResNet18_Weights.DEFAULT if cached
        self.backbone.fc = nn.Identity()
        feat_dim = 512

        if use_ltc:
            self.temporal = LiquidTemporalModule(feat_dim, hidden_dim=256, num_layers=2)
            self.classifier = nn.Linear(256, num_classes)
        else:
            self.classifier = nn.Linear(feat_dim, num_classes)

        self.backbone_layer4 = self.backbone.layer4

    def forward(self, x):
        if x.dim() == 5:
            x = x.squeeze(1)
        x = self.dehaze(x)
        features = self.backbone.conv1(x)
        features = self.backbone.bn1(features)
        features = self.backbone.relu(features)
        features = self.backbone.maxpool(features)
        features = self.backbone.layer1(features)
        features = self.backbone.layer2(features)
        features = self.backbone.layer3(features)
        features = self.backbone.layer4(features)

        pooled = F.adaptive_avg_pool2d(features, (1,1)).view(x.size(0), -1)
        if self.use_ltc:
            latent = self.temporal([pooled])
        else:
            latent = pooled
        logits = self.classifier(latent)
        return logits, features

# ---------------------------
# Training Utilities
# ---------------------------
def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    pbar = tqdm(dataloader, desc="Training")
    for imgs, labels, _ in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits, _ = model(imgs.unsqueeze(1))
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
        pbar.set_postfix({"Loss": loss.item(), "Acc": correct/total})
    return total_loss / len(dataloader), correct / total

def validate(model, dataloader, device):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels, _ in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            logits, _ = model(imgs.unsqueeze(1))
            loss = F.cross_entropy(logits, labels)
            total_loss += loss.item()
            preds = logits.argmax(dim=1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)
    return total_loss / len(dataloader), correct / total

In [16]:
# ---------------------------
# Visualization with Grad-CAM ‚Äî One per class
# ---------------------------
def visualize_gradcam_per_class(model, dataloader, device, gradcam, class_names):
    model.eval()
    # Collect one example per class
    examples = {0: None, 1: None, 2: None}

    for imgs, labels, paths in dataloader:
        for i in range(len(labels)):
            label = labels[i].item()
            if examples[label] is None:
                examples[label] = (imgs[i], label)
            if all(v is not None for v in examples.values()):
                break
        if all(v is not None for v in examples.values()):
            break

    # Plot one per class
    fig, axes = plt.subplots(3, 3, figsize=(12, 12))
    fig.suptitle("Grad-CAM: One Sample Per Fog Class", fontsize=16)

    for row, (label, (img_tensor, true_label)) in enumerate(examples.items()):
        img_tensor = img_tensor.unsqueeze(0).to(device)  # [1, C, H, W]
        img_np = img_tensor[0].cpu().numpy().transpose(1, 2, 0)

        # Get prediction
        with torch.no_grad():
            logits, _ = model(img_tensor.unsqueeze(1))  # Add seq dim
            pred_class = logits.argmax(dim=1).item()

        # Generate Grad-CAM for predicted class
        cam = gradcam(img_tensor, class_idx=pred_class)

        # Original
        axes[row, 0].imshow(img_np)
        axes[row, 0].set_title(f"True: {class_names[true_label]}\nPred: {class_names[pred_class]}")
        axes[row, 0].axis('off')

        # Heatmap
        axes[row, 1].imshow(cam, cmap='jet')
        axes[row, 1].set_title("Grad-CAM")
        axes[row, 1].axis('off')

        # Overlay
        overlay = np.uint8(255 * img_np * 0.6 + cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET) * 0.4)
        axes[row, 2].imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
        axes[row, 2].set_title("Overlay")
        axes[row, 2].axis('off')

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [None]:


# # ---------------------------
# # Visualization with Grad-CAM
# # ---------------------------
# def visualize_gradcam(model, dataloader, device, gradcam, num_samples=3):
#     model.eval()
#     count = 0
#     for imgs, labels, paths in dataloader:
#         if count >= num_samples:
#             break
#         imgs, labels = imgs.to(device), labels.to(device)
#         img_np = imgs[0].cpu().numpy().transpose(1, 2, 0)
#         img_tensor = imgs[0:1].unsqueeze(1)

#         with torch.no_grad():
#             logits, _ = model(img_tensor)
#             pred_class = logits.argmax(dim=1).item()

#         cam = gradcam(img_tensor.squeeze(1), class_idx=pred_class)

#         plt.figure(figsize=(12, 4))
#         plt.subplot(1, 3, 1)
#         plt.imshow(img_np)
#         plt.title(f"Original\nTrue: {Config.classes[labels[0]]}\nPred: {Config.classes[pred_class]}")
#         plt.axis('off')

#         plt.subplot(1, 3, 2)
#         plt.imshow(cam, cmap='jet')
#         plt.title("Grad-CAM")
#         plt.axis('off')

#         overlay = np.uint8(255 * img_np * 0.6 + cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET) * 0.4)
#         plt.subplot(1, 3, 3)
#         plt.imshow(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB))
#         plt.title("Overlay")
#         plt.axis('off')

#         plt.tight_layout()
#         plt.show()
#         count += 1


if __name__ == "__main__":
    DATA_ROOT = "/kaggle/input/foggy-cityscapes-image-dataset/Foggy_Cityscapes"  

    # Step 1: Collect all samples
    all_samples = []
    class_to_idx = {cls: idx for idx, cls in enumerate(Config.classes)}

    for class_name in Config.classes:
        class_dir = os.path.join(DATA_ROOT, class_name)
        if not os.path.exists(class_dir):
            raise FileNotFoundError(f"Directory not found: {class_dir}")
        for fname in os.listdir(class_dir):
            if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                all_samples.append((os.path.join(class_dir, fname), class_to_idx[class_name]))

    print(f"‚úÖ Total images loaded: {len(all_samples)}")

    # Step 2: Stratified split (80% train, 20% val)
    image_paths, labels = zip(*all_samples)
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        image_paths, labels, test_size=0.2, random_state=42, stratify=labels
    )

    train_samples = list(zip(train_paths, train_labels))
    val_samples = list(zip(val_paths, val_labels))

    print(f"SplitOptions ‚Üí Train: {len(train_samples)}, Val: {len(val_samples)}")
    print(f"Train class counts: {np.bincount(train_labels)}")
    print(f"Val class counts: {np.bincount(val_labels)}")

    # Step 3: Transforms
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((Config.img_size, Config.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Step 4: Datasets & Loaders
    train_ds = FogClassificationDataset(train_samples, transform=transform)
    val_ds = FogClassificationDataset(val_samples, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=Config.batch_size, shuffle=True, num_workers=Config.num_workers)
    val_loader = DataLoader(val_ds, batch_size=Config.batch_size, shuffle=False, num_workers=Config.num_workers)

    # Step 5: Model + Training
    model = FogClassifier(num_classes=Config.num_classes, use_ltc=True).to(Config.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.epochs)

    gradcam = GradCAM(model, model.backbone_layer4)

    best_acc = 0.0
    for epoch in range(Config.epochs):
        tr_loss, tr_acc = train_one_epoch(model, train_loader, optimizer, Config.device)
        val_loss, val_acc = validate(model, val_loader, Config.device)
        scheduler.step()

        print(f"Epoch {epoch+1}/{Config.epochs} - Train Loss: {tr_loss:.4f}, Acc: {tr_acc:.4f} | Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), os.path.join(Config.SAVE_DIR, "best_fog_classifier.pth"))
            print("‚úÖ Saved best model!")

    # # Step 6: Visualize
    # print("\nüîç Generating Grad-CAM visualizations...")
    # visualize_gradcam(model, val_loader, Config.device, gradcam, num_samples=3)
    # gradcam.remove_hook()

‚úÖ Total images loaded: 1500
SplitOptions ‚Üí Train: 1200, Val: 300
Train class counts: [400 400 400]
Val class counts: [100 100 100]


  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)
Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:26<00:00,  5.15s/it, Loss=0.364, Acc=0.709]


Epoch 1/10 - Train Loss: 0.6720, Acc: 0.7092 | Val Loss: 1.1688, Acc: 0.4833
‚úÖ Saved best model!


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:21<00:00,  5.08s/it, Loss=0.349, Acc=0.802]


Epoch 2/10 - Train Loss: 0.4813, Acc: 0.8025 | Val Loss: 0.8121, Acc: 0.6667
‚úÖ Saved best model!


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:17<00:00,  5.03s/it, Loss=0.371, Acc=0.839]


Epoch 3/10 - Train Loss: 0.4171, Acc: 0.8392 | Val Loss: 0.3592, Acc: 0.8700
‚úÖ Saved best model!


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:19<00:00,  5.06s/it, Loss=0.305, Acc=0.838]


Epoch 4/10 - Train Loss: 0.4172, Acc: 0.8375 | Val Loss: 0.6963, Acc: 0.6967


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:18<00:00,  5.05s/it, Loss=0.277, Acc=0.877] 


Epoch 5/10 - Train Loss: 0.3007, Acc: 0.8775 | Val Loss: 0.2306, Acc: 0.9267
‚úÖ Saved best model!


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:25<00:00,  5.15s/it, Loss=0.272, Acc=0.902] 


Epoch 6/10 - Train Loss: 0.2670, Acc: 0.9025 | Val Loss: 0.2297, Acc: 0.9267


Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [06:20<00:00,  5.07s/it, Loss=0.0479, Acc=0.952]


Epoch 7/10 - Train Loss: 0.1391, Acc: 0.9517 | Val Loss: 0.1853, Acc: 0.9300
‚úÖ Saved best model!


Training:  15%|‚ñà‚ñç        | 11/75 [01:01<05:27,  5.12s/it, Loss=0.0458, Acc=0.932]