In [None]:

import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt


In [None]:
# =====================================================
# Dataset ADE20K Indoor  (filtré pour robot)
# =====================================================
class ADE20KIndoorRobotDataset(Dataset):
    """ADE20K indoor segmentation dataset reduced to classes for a home robot.

    Images and masks are paired purely by position after sorting both
    directory listings, so the two folders must contain corresponding files
    that sort in the same order.

    Every raw mask value listed in ``class_map`` is translated to a compact
    training id; any value that is not listed becomes 0.  The intended
    training ids are 0: background, 1: wall, 2: floor, 3: table, 4: chair,
    5: bed.
    NOTE(review): the raw ids (7, 11, 12, 13, 14) depend on how this subset's
    annotations are encoded -- confirm them against the dataset's label file.

    Args:
        images_root: directory containing the input images.
        masks_root: directory containing the annotation masks.
        image_size: target size handed to ``T.Resize`` for images and masks.

    Raises:
        ValueError: if the two directories hold a different number of files,
            which would otherwise silently mis-pair images and masks.
    """

    # raw annotation value -> training class id
    class_map = {
        0: 0,
        7: 1,
        11: 2,
        12: 3,
        13: 4,
        14: 5
    }

    def __init__(self, images_root, masks_root, image_size=(224, 224)):
        self.images_root = images_root
        self.masks_root = masks_root

        self.images = sorted(os.listdir(images_root))
        self.masks = sorted(os.listdir(masks_root))

        # Pairing is positional, so a count mismatch means at least one sample
        # would be matched with the wrong mask (or fail late with IndexError).
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Image/mask count mismatch: {len(self.images)} files in "
                f"{images_root!r} vs {len(self.masks)} files in {masks_root!r}"
            )

        self.image_transform = T.Compose([
            T.Resize(image_size),
            T.ToTensor()
        ])

        # Nearest-neighbour keeps mask values as valid class ids; any
        # smoothing filter would invent ids that exist in neither neighbour.
        self.mask_transform = T.Compose([
            T.Resize(image_size, interpolation=T.InterpolationMode.NEAREST)
        ])

    def __len__(self):
        """Return the number of image files found in ``images_root``."""
        return len(self.images)

    @classmethod
    def _remap_mask(cls, mask):
        """Translate raw annotation values to training ids (unlisted -> 0).

        Returns a new int64 array with the same shape as ``mask``.
        """
        remapped = np.zeros(mask.shape, dtype=np.int64)
        for orig, new in cls.class_map.items():
            # Comparisons always use the untouched input, so already-written
            # ids can never be remapped a second time.
            remapped[mask == orig] = new
        return remapped

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask)``.

        ``image`` is the output of ``image_transform`` applied to the RGB
        image; ``mask`` is a ``torch.long`` tensor of training class ids.
        """
        img_path = os.path.join(self.images_root, self.images[idx])
        mask_path = os.path.join(self.masks_root, self.masks[idx])

        # Context managers release the underlying file handles promptly.
        with Image.open(img_path) as raw_image:
            image = self.image_transform(raw_image.convert("RGB"))

        with Image.open(mask_path) as raw_mask:
            mask = np.array(self.mask_transform(raw_mask))

        # int64 array -> torch.long tensor without an extra copy.
        mask = torch.from_numpy(self._remap_mask(mask))

        return image, mask


In [None]:
# =====================================================
# DataLoaders (Kaggle paths)
# =====================================================
# Training split: images and their annotation masks.
train_dataset = ADE20KIndoorRobotDataset(
    "/kaggle/input/ade20k-subset/indoor/images/training",
    "/kaggle/input/ade20k-subset/indoor/annotations/training",
)

# Validation split, same directory layout.
val_dataset = ADE20KIndoorRobotDataset(
    "/kaggle/input/ade20k-subset/indoor/images/validation",
    "/kaggle/input/ade20k-subset/indoor/annotations/validation",
)

# Only the training batches are shuffled; validation keeps the dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)


In [None]:
# =====================================================
# Vision Transformer (encodeur simple)
# =====================================================
class PatchEmbedding(nn.Module):
    """Project an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every non-overlapping patch to an ``embed_dim``-channel vector.
    ``img_size`` is accepted for signature compatibility but is not used
    by this module.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens with the embedding dimension last."""
        feature_map = self.proj(x)
        # Collapse the spatial grid into one token axis, then move the
        # embedding channels to the last position.
        tokens = feature_map.flatten(start_dim=2)
        return tokens.permute(0, 2, 1)

In [None]:
class SimpleViT(nn.Module):
    """Small Vision Transformer encoder that returns a 2-D feature map.

    The input is split into patches by ``PatchEmbedding``, a learnable class
    token is prepended, learnable position embeddings are added and the
    sequence goes through a stack of ``nn.TransformerEncoderLayer``.  The
    class token is dropped from the output and the remaining tokens are
    reshaped to ``(batch, embed_dim, grid, grid)`` with
    ``grid = img_size // patch_size``.

    The position embedding has a fixed length, so inputs are expected to be
    square images of side ``img_size``.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        n_layers=4,
        n_heads=8,
        mlp_dim=1024
    ):
        super().__init__()

        # Plain configuration values used again in forward().
        self.img_size = img_size
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)

        # Both learnable tensors start at zero; the extra position slot
        # belongs to the class token.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim)
        )

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=n_heads,
                dim_feedforward=mlp_dim,
                batch_first=True,
            ),
            num_layers=n_layers,
        )

    def forward(self, x):
        """Encode ``x`` and return the patch tokens as a spatial feature map."""
        batch = x.size(0)
        grid = self.img_size // self.patch_size

        patches = self.patch_embed(x)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat([cls, patches], dim=1) + self.pos_embed

        encoded = self.encoder(tokens)

        # Index 0 is the class token; only patch tokens carry spatial content.
        patch_tokens = encoded[:, 1:]
        return patch_tokens.permute(0, 2, 1).reshape(
            batch, self.embed_dim, grid, grid
        )

In [None]:
# =====================================================
# Decoder
# =====================================================
class SimpleDecoder(nn.Module):
    """Segmentation head: a single 1x1 convolution producing class logits.

    The spatial size of the input feature map is left unchanged; only the
    channel count goes from ``in_channels`` to ``n_classes``.
    """

    def __init__(self, in_channels, n_classes):
        super().__init__()
        self.head = nn.Conv2d(
            in_channels=in_channels,
            out_channels=n_classes,
            kernel_size=1,
        )

    def forward(self, x):
        """Return per-location class logits for feature map ``x``."""
        logits = self.head(x)
        return logits

In [None]:
# =====================================================
# Modèle complet
# =====================================================
n_classes = 6

# The encoder is built with its default embed_dim of 768, which is the
# channel count the decoder head has to accept.
encoder = SimpleViT()
decoder = SimpleDecoder(in_channels=768, n_classes=n_classes)

# Use the GPU when one is visible, otherwise stay on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(encoder, decoder).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# =====================================================
# Boucle d'entraînement
# =====================================================
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)

        # The logits are coarser than the masks, so the targets are shrunk to
        # the logits' spatial size.  Nearest-neighbour sampling keeps every
        # value a valid class id; interpolate() needs a float tensor with a
        # channel axis, hence the round-trip through unsqueeze/float.
        masks_resized = masks.unsqueeze(1).float()
        masks_resized = nn.functional.interpolate(
            masks_resized, size=outputs.shape[2:], mode="nearest"
        )
        masks_resized = masks_resized.squeeze(1).long()

        loss = criterion(outputs, masks_resized)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Report the mean batch loss of the epoch.
    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Loss: {running_loss/len(train_loader):.4f}"
    )

In [None]:
# =====================================================
# Visualisation des résultats
# =====================================================
model.eval()

# A single validation batch is enough for a quick visual check.
images, masks = next(iter(val_loader))
images = images.to(device)

with torch.no_grad():
    outputs = model(images)

# Most likely class per location, moved to the CPU for plotting.
preds = outputs.argmax(dim=1).cpu()

# Two rows of up to four panels: inputs on top, predictions underneath.
fig = plt.figure(figsize=(12, 4))
for i in range(min(4, images.size(0))):
    top = fig.add_subplot(2, 4, i + 1)
    top.imshow(images[i].cpu().permute(1, 2, 0))
    top.set_title("Image")
    top.axis("off")

    bottom = fig.add_subplot(2, 4, i + 5)
    bottom.imshow(preds[i])
    bottom.set_title("Pred Mask")
    bottom.axis("off")

plt.show()


In [None]:
# ===============================
# Évaluation + Visualisation
# ===============================
import matplotlib.pyplot as plt

model.eval()
results_to_show = 14  # total number of samples to display
shown = 0  # samples displayed so far

with torch.no_grad():
    for images, masks in val_loader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)        # class logits at patch resolution
        preds = outputs.argmax(dim=1)  # most likely class per patch

        # Bring the coarse prediction up to the mask resolution; nearest
        # sampling keeps the values valid class ids.
        preds_up = preds.unsqueeze(1).float()
        preds_up = nn.functional.interpolate(
            preds_up, size=masks.shape[1:], mode='nearest'
        )
        preds_up = preds_up.squeeze(1).long()

        for i in range(images.shape[0]):
            if shown >= results_to_show:
                break  # enough samples displayed

            # (title, picture, extra imshow options) for the three panels
            panels = [
                ("Image", images[i].cpu().permute(1, 2, 0), {}),
                ("Mask GT", masks[i].cpu(), {"cmap": 'tab20'}),
                ("Prediction", preds_up[i].cpu(), {"cmap": 'tab20'}),
            ]

            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            for ax, (title, picture, opts) in zip(axs, panels):
                ax.imshow(picture, **opts)
                ax.set_title(title)
                ax.axis('off')
            plt.show()

            shown += 1

        if shown >= results_to_show:
            break  # stop iterating over batches as well