In [1]:
import os
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
from torchvision.models.segmentation import deeplabv3_resnet50

class RemoteSensingDataset(Dataset):
    def __init__(self, image_dir, label_dir=None, transform=None):
        assert os.path.isdir(image_dir), f"Không tìm thấy folder {image_dir}"
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transform = transform

        self.images = sorted(os.listdir(self.image_dir))
        if label_dir:
            assert os.path.isdir(label_dir), f"Không tìm thấy folder {label_dir}"
            self.labels = sorted(os.listdir(self.label_dir))
            assert len(self.images) == len(self.labels), "Số ảnh và mask phải khớp"
        else:
            self.labels = None

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)

        if self.labels:
            m_path = os.path.join(self.label_dir, self.labels[idx])
            mask = Image.open(m_path).convert("L")
            mask = mask.resize((img.shape[1], img.shape[2]), resample=Image.NEAREST)
            mask = torch.from_numpy(np.array(mask)).long()
        else:
            mask = torch.full((img.shape[1], img.shape[2]), 255, dtype=torch.long)
        return img, mask


In [2]:
# Đánh giá IoU
def mean_iou(pred, target, num_classes):
    ious = []
    for cls in range(num_classes):
        p = (pred == cls)
        t = (target == cls)
        inter = (p & t).sum().float()
        union = (p | t).sum().float()
        if union > 0:
            ious.append((inter / union).item())
    return np.mean(ious) if ious else 1.0


# Mô hình DiverseHead
class DiverseHead(nn.Module):
    def __init__(self, num_classes, num_heads=10, dropout_rate=0.3):
        super().__init__()
        self.backbone = deeplabv3_resnet50(pretrained=False, num_classes=num_classes)
        in_ch = self.backbone.classifier[0].project[0].out_channels
        self.heads = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_ch, 256, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout_rate),
                nn.Conv2d(256, num_classes, 1)
            ) for _ in range(num_heads)
        ])
        self.num_heads = num_heads

    def forward(self, x):
        feats = self.backbone.backbone(x)['out']
        feats = self.backbone.classifier[0](feats)
        outs = []
        for head in self.heads:
            o = head(feats)
            o = F.interpolate(o, size=x.shape[2:], mode='bilinear', align_corners=False)
            outs.append(o)
        return outs

    def freeze_heads(self, frozen_ids):
        for idx, head in enumerate(self.heads):
            for p in head.parameters():
                p.requires_grad = (idx not in frozen_ids)


# Hàm tạo pseudo-label với mean voting + max voting
def generate_pseudo_labels_voting(model, dl_unlab, device, num_classes=2, threshold=0.6, phi=2.0):
    model.eval()
    pseudo_imgs, pseudo_labels = [], []
    with torch.no_grad():
        for x_u, _ in dl_unlab:
            x_u = x_u.to(device)
            outs = model(x_u)

            probs_heads = [F.softmax(o, dim=1) for o in outs]              # List of [B, C, H, W]
            probs_stack = torch.stack(probs_heads, dim=0)                 # [L, B, C, H, W]
            probs_mean = probs_stack.mean(dim=0)                          # [B, C, H, W]
            mean_label = probs_mean.argmax(dim=1)                         # [B, H, W]

            vote_counts = torch.zeros_like(probs_mean, dtype=torch.float)  # [B, C, H, W]
            for lbl in [p.argmax(dim=1) for p in probs_heads]:
                for c in range(num_classes):
                    vote_counts[:, c] += (lbl == c).float()

            for c in range(num_classes):
                vote_counts[:, c] += phi * (mean_label == c).float()

            pseudo = vote_counts.argmax(dim=1)                            # [B, H, W]
            max_probs, _ = probs_mean.max(dim=1)                          # [B, H, W]
            pseudo[max_probs < threshold] = 255                           # Loại bỏ pixel không tin tưởng

            pseudo_imgs.append(x_u.cpu())
            pseudo_labels.append(pseudo.cpu())
    return torch.cat(pseudo_imgs, dim=0), torch.cat(pseudo_labels, dim=0)


In [3]:
# Thiết bị
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2
BS = 4

# Các phép biến đổi (augmentation)
transform_train = T.Compose([
    T.Resize((256, 256)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(0.3, 0.3, 0.3, 0.1),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3),
])
transform_val = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3),
])

# Đường dẫn dữ liệu
train_lab_dir = "img_test_ver2/img_test_ver2/train_have_label"
train_lab_mask = "img_test_ver2/img_test_ver2/mask_train_label"
train_unlab_dir = "img_test_ver2/img_test_ver2/train_unlabel"
val_img_dir = "img_test_ver2/img_test_ver2/val"
val_mask_dir = "img_test_ver2/img_test_ver2/val_labels"

# Dataset và DataLoader
ds_lab = RemoteSensingDataset(train_lab_dir, train_lab_mask, transform=transform_train)
ds_unlab = RemoteSensingDataset(train_unlab_dir, None, transform=transform_train)
ds_val = RemoteSensingDataset(val_img_dir, val_mask_dir, transform=transform_val)

dl_lab = DataLoader(ds_lab, BS, shuffle=True, num_workers=0, pin_memory=True)
dl_unlab = DataLoader(ds_unlab, BS, shuffle=True, num_workers=0, pin_memory=True)
dl_val = DataLoader(ds_val, BS, shuffle=False, num_workers=0, pin_memory=True)

# Khởi tạo mô hình, tối ưu hóa, loss
model = DiverseHead(num_classes=num_classes, num_heads=10, dropout_rate=0.3).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=255)




In [4]:
# Phase 1: Huấn luyện với dữ liệu gán nhãn + không gán nhãn (semi-supervised)
num_epochs = 20
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    it_lab = iter(dl_lab)
    it_unlab = iter(dl_unlab)
    steps = max(len(dl_lab), len(dl_unlab))

    for _ in range(steps):
        try:
            x_l, y_l = next(it_lab)
        except StopIteration:
            it_lab = iter(dl_lab)
            x_l, y_l = next(it_lab)

        try:
            x_u, _ = next(it_unlab)
        except StopIteration:
            it_unlab = iter(dl_unlab)
            x_u, _ = next(it_unlab)

        x_l, y_l = x_l.to(device), y_l.to(device)
        x_u = x_u.to(device)

        # Dynamic freezing: đóng băng một nửa số head
        frozen = random.sample(range(model.num_heads), model.num_heads // 2)
        model.freeze_heads(frozen)

        outs_l = model(x_l)  # supervised outputs
        outs_u = model(x_u)  # unsupervised outputs

        # Loss có label
        sup_loss = sum(criterion(o, y_l) for o in outs_l) / model.num_heads

        # Loss không có label (pseudo từ trung bình)
        probs = torch.stack([F.softmax(o, 1) for o in outs_u], 0).mean(0)
        pseudo = probs.argmax(1).detach()
        idx = random.randrange(model.num_heads)
        unsup_loss = criterion(outs_u[idx], pseudo)

        loss = sup_loss + unsup_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.freeze_heads([])

        running_loss += loss.item()

    # Validation sau mỗi epoch
    model.eval()
    miou_sum, cnt = 0, 0
    with torch.no_grad():
        for x_v, y_v in dl_val:
            x_v, y_v = x_v.to(device), y_v.to(device)
            outs = model(x_v)
            out = torch.stack(outs, 0).mean(0).argmax(1)
            miou_sum += mean_iou(out.cpu(), y_v.cpu(), num_classes)
            cnt += 1
    val_miou = miou_sum / cnt
    print(f"[Phase1] Ep {epoch}/{num_epochs} | loss={running_loss/steps:.4f} | Val mIoU={val_miou:.4f}")

# Lưu model sau phase 1
torch.save(model.state_dict(), "diversehead_phase1.pth")


[Phase1] Ep 1/20 | loss=0.3965 | Val mIoU=0.8913
[Phase1] Ep 2/20 | loss=0.0113 | Val mIoU=0.8913
[Phase1] Ep 3/20 | loss=0.0050 | Val mIoU=0.8913
[Phase1] Ep 4/20 | loss=0.0023 | Val mIoU=0.8913
[Phase1] Ep 5/20 | loss=0.0020 | Val mIoU=0.8913
[Phase1] Ep 6/20 | loss=0.0014 | Val mIoU=0.8913
[Phase1] Ep 7/20 | loss=0.0014 | Val mIoU=0.8913
[Phase1] Ep 8/20 | loss=0.0011 | Val mIoU=0.8913
[Phase1] Ep 9/20 | loss=0.0010 | Val mIoU=0.8913
[Phase1] Ep 10/20 | loss=0.0009 | Val mIoU=0.8913
[Phase1] Ep 11/20 | loss=0.0007 | Val mIoU=0.8913
[Phase1] Ep 12/20 | loss=0.0006 | Val mIoU=0.8913
[Phase1] Ep 13/20 | loss=0.0007 | Val mIoU=0.8913
[Phase1] Ep 14/20 | loss=0.0006 | Val mIoU=0.8913
[Phase1] Ep 15/20 | loss=0.0007 | Val mIoU=0.8913
[Phase1] Ep 16/20 | loss=0.0006 | Val mIoU=0.8913
[Phase1] Ep 17/20 | loss=0.0005 | Val mIoU=0.8913
[Phase1] Ep 18/20 | loss=0.0005 | Val mIoU=0.8913
[Phase1] Ep 19/20 | loss=0.0003 | Val mIoU=0.8913
[Phase1] Ep 20/20 | loss=0.0003 | Val mIoU=0.8913


In [None]:
# Phase 2: Tạo pseudo-labels với mean + max voting
print("Generating pseudo-labels with voting...")
pseudo_imgs, pseudo_labels = generate_pseudo_labels_voting(
    model, dl_unlab, device, num_classes=num_classes, threshold=0.6, phi=2.0
)
print(f"Generated {pseudo_imgs.shape[0]} pseudo-labeled samples.")

# Lưu pseudo masks dưới dạng ảnh grayscale
save_dir = "saved_pseudo_labels"
os.makedirs(save_dir, exist_ok=True)
for i in range(pseudo_labels.shape[0]):
    mask = pseudo_labels[i].numpy().astype(np.uint8)
    m_pil = Image.fromarray(mask, mode="L")
    m_pil.save(os.path.join(save_dir, f"pseudo_mask_{i:03d}.png"))
print(f"Saved pseudo-labels to {save_dir}/")

# (Tùy chọn) Augment nhẹ để tăng robustness
pseudo_imgs = torch.clamp(pseudo_imgs + 0.01 * torch.randn_like(pseudo_imgs), -1, 1)

# Dataset từ pseudo-labels
class PseudoDataset(Dataset):
    def __init__(self, imgs, masks):
        self.imgs = imgs
        self.masks = masks
    def __len__(self): return len(self.imgs)
    def __getitem__(self, idx): return self.imgs[idx], self.masks[idx]

ds_pseudo = PseudoDataset(pseudo_imgs, pseudo_labels)
dl_pseudo = DataLoader(ds_pseudo, BS, shuffle=True, num_workers=0, pin_memory=True)

# Phase 2: Fine-tune với dữ liệu gán nhãn + pseudo-label
ft_epochs = 10
for epoch in range(1, ft_epochs + 1):
    model.train()
    running_loss = 0.0
    it_lab = iter(dl_lab)
    it_pseudo = iter(dl_pseudo)
    steps = max(len(dl_lab), len(dl_pseudo))

    for _ in range(steps):
        try:
            x_l, y_l = next(it_lab)
        except StopIteration:
            it_lab = iter(dl_lab)
            x_l, y_l = next(it_lab)
        try:
            x_p, y_p = next(it_pseudo)
        except StopIteration:
            it_pseudo = iter(dl_pseudo)
            x_p, y_p = next(it_pseudo)

        x_l, y_l = x_l.to(device), y_l.to(device)
        x_p, y_p = x_p.to(device), y_p.to(device)

        frozen = random.sample(range(model.num_heads), model.num_heads // 2)
        model.freeze_heads(frozen)

        outs_l = model(x_l)
        outs_p = model(x_p)

        sup_loss = sum(criterion(o, y_l) for o in outs_l) / model.num_heads
        pseudo_loss = sum(criterion(o, y_p) for o in outs_p) / model.num_heads

        loss = sup_loss + pseudo_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        model.freeze_heads([])

        running_loss += loss.item()

    # Validation sau mỗi epoch
    model.eval()
    miou_sum, cnt = 0, 0
    with torch.no_grad():
        for x_v, y_v in dl_val:
            x_v, y_v = x_v.to(device), y_v.to(device)
            outs = model(x_v)
            out = torch.stack(outs, 0).mean(0).argmax(1)
            miou_sum += mean_iou(out.cpu(), y_v.cpu(), num_classes)
            cnt += 1
    val_miou = miou_sum / cnt
    print(f"[Fine-tune] Ep {epoch}/{ft_epochs} | loss={running_loss/steps:.4f} | Val mIoU={val_miou:.4f}")

# Lưu model cuối cùng
out_path = "diversehead_final_selftrain.pth"
torch.save(model.state_dict(), out_path)
print(f"Finished training. Model saved to {out_path}")


Generating pseudo-labels with voting...
Generated 103 pseudo-labeled samples.
Saved pseudo-labels to saved_pseudo_labels/


In [None]:
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

model.eval()
save_pred_dir = "predictions"
os.makedirs(save_pred_dir, exist_ok=True)

with torch.no_grad():
    for i, (x, y) in enumerate(dl_val):
        x = x.to(device)
        outs = model(x)
        pred = torch.stack(outs, 0).mean(0).argmax(1)  # [B, H, W]
        for b in range(x.size(0)):
            mask = pred[b].cpu().numpy().astype(np.uint8) * 127
            out_img = Image.fromarray(mask, mode="L")
            out_img.save(os.path.join(save_pred_dir, f"pred_{i*BS + b:03d}.png"))
print(f"Saved predictions to {save_pred_dir}/")
