In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.models import vgg16, VGG16_Weights
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from PIL import Image
import numpy as np
import os
import glob
import matplotlib.pyplot as plt

# Training

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Root folder for everything the training loop writes (checkpoints, samples).
# Created up front; an already existing folder is not an error.
output_dir = "output"
os.makedirs(name=output_dir, exist_ok=True)

In [None]:
class UNetGenerator(nn.Module):
    """Three-level U-Net that maps an image to an image of the same spatial size.

    The contracting path halves the resolution three times (64 -> 128 -> 256
    channels, then a 512-channel bottleneck); the expanding path mirrors it
    with transposed convolutions and concatenated skip connections. The output
    goes through ``tanh`` and therefore lies in [-1, 1].

    Height and width should be multiples of 8: otherwise the three 2x poolings
    round down and the upsampled maps no longer match the skip tensors.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the generated image.
    """

    @staticmethod
    def _double_conv(channels_in, channels_out):
        """Return two stacked 3x3 conv -> BatchNorm -> ReLU stages (size preserving)."""
        stages = []
        for source_channels in (channels_in, channels_out):
            stages += [
                nn.Conv2d(source_channels, channels_out, 3, padding=1),
                nn.BatchNorm2d(channels_out),
                nn.ReLU(inplace=True),
            ]
        return nn.Sequential(*stages)

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        make_block = self._double_conv

        # Contracting path (attribute names/order are part of the checkpoint format).
        self.encoder1 = make_block(in_channels, 64)
        self.encoder2 = make_block(64, 128)
        self.encoder3 = make_block(128, 256)
        self.pool = nn.MaxPool2d(2, 2)

        self.bottleneck = make_block(256, 512)

        # Expanding path: each decoder sees upsampled features + the matching skip,
        # hence twice the channel count of the transposed convolution's output.
        self.upconv3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.decoder3 = make_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.decoder2 = make_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.decoder1 = make_block(128, 64)

        # 1x1 projection to the requested number of output channels.
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        """Run the U-Net on ``x`` (N, in_channels, H, W) and return a tanh-squashed map."""
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool(skip1))
        skip3 = self.encoder3(self.pool(skip2))
        features = self.bottleneck(self.pool(skip3))

        # Walk back up, deepest level first: upsample, join the skip, refine.
        expanding_path = (
            (self.upconv3, skip3, self.decoder3),
            (self.upconv2, skip2, self.decoder2),
            (self.upconv1, skip1, self.decoder1),
        )
        for upsample, skip, refine in expanding_path:
            features = refine(torch.cat([upsample(features), skip], dim=1))

        return self.final_conv(features).tanh()

In [None]:
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator that scores image patches instead of the whole image.

    Three stride-2 4x4 convolutions (64 -> 128 -> 256 channels) shrink the
    input by a factor of 8; a final stride-1 4x4 convolution followed by a
    sigmoid yields a single-channel map of per-patch probabilities in [0, 1].

    Args:
        in_channels: number of channels of the images being judged.
    """

    def __init__(self, in_channels=1):
        super().__init__()

        # First stage has no normalisation layer.
        layers = [
            nn.Conv2d(in_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining downsampling stages: conv -> BatchNorm -> LeakyReLU.
        for width_in, width_out in ((64, 128), (128, 256)):
            layers.extend([
                nn.Conv2d(width_in, width_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(width_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        # Patch-level real/fake probability head.
        layers.extend([
            nn.Conv2d(256, 1, 4, padding=1),
            nn.Sigmoid(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the (N, 1, h, w) map of patch probabilities for batch ``x``."""
        patch_scores = self.model(x)
        return patch_scores

In [None]:
class LossFunctions:
    """Collection of the loss terms used to train the restoration GAN.

    Holds a binary cross-entropy loss (adversarial term), an L1 loss (pixel
    term, also used to compare VGG features and Gram matrices) and a frozen
    ImageNet-pretrained VGG16 feature extractor placed on the module-level
    ``device``.

    The VGG-based losses assume image tensors scaled to [-1, 1] (the range
    produced by ``Normalize(mean=[0.5], std=[0.5])`` and by a ``tanh`` output).
    Before entering VGG they are rescaled to [0, 1] and normalised with the
    ImageNet channel statistics the pretrained weights were trained with.
    """

    def __init__(self):
        self.bce = nn.BCELoss()
        self.l1 = nn.L1Loss()
        # Frozen feature extractor: eval mode and no gradients for its weights
        # (gradients still flow *through* it to the generator output).
        self.vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features.to(device).eval()
        for param in self.vgg.parameters():
            param.requires_grad = False
        # ImageNet per-channel statistics, shaped for broadcasting over (N, 3, H, W).
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    def _vgg_features(self, img):
        """Return VGG16 features of ``img`` after ImageNet-style preprocessing.

        Args:
            img: tensor of shape (N, 1, H, W) or (N, 3, H, W), values assumed
                to lie in [-1, 1].
        """
        # Grayscale input is replicated to the three channels VGG expects.
        if img.size(1) == 1:
            img = img.repeat(1, 3, 1, 1)
        img = img * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        img = (img - self.imagenet_mean) / self.imagenet_std
        return self.vgg(img)

    def adversarial_loss(self, pred, target):
        """Binary cross-entropy between discriminator probabilities and labels."""
        return self.bce(pred, target)

    def pixel_loss(self, pred, target):
        """Mean absolute difference between two images."""
        return self.l1(pred, target)

    def perceptual_loss(self, pred, target):
        """L1 distance between the VGG16 feature maps of ``pred`` and ``target``."""
        pred_features = self._vgg_features(pred)
        target_features = self._vgg_features(target)
        return self.l1(pred_features, target_features)

    def style_loss(self, pred, target):
        """L1 distance between the Gram matrices of the VGG16 feature maps."""
        def gram_matrix(x):
            # Channel-by-channel correlations, normalised by the feature volume.
            b, c, h, w = x.size()
            x = x.flatten(2)
            return torch.bmm(x, x.transpose(1, 2)) / (c * h * w)

        pred_gram = gram_matrix(self._vgg_features(pred))
        target_gram = gram_matrix(self._vgg_features(target))
        return self.l1(pred_gram, target_gram)

In [None]:
class MedicalImageDataset(Dataset):
    """Paired dataset of attacked ("infected") images and their ground truth.

    Both directories are scanned for ``*.png`` files. The sorted listings must
    be non-empty, of equal length and match name by name; position ``i`` in
    both listings then forms one training pair.

    Args:
        infected_dir: directory holding the attacked images.
        gt_dir: directory holding the clean reference images.
        transform: optional callable applied to each PIL image of a pair.

    Raises:
        ValueError: if a directory has no PNGs, the counts differ, or a pair
            of file names does not match.
    """

    def __init__(self, infected_dir, gt_dir, transform=None):
        self.infected_dir = infected_dir
        self.gt_dir = gt_dir
        self.transform = transform

        # Load the file listings (sorted so both sides line up by name).
        infected_paths = glob.glob(os.path.join(infected_dir, "*.png"))
        infected_paths.sort()
        gt_paths = glob.glob(os.path.join(gt_dir, "*.png"))
        gt_paths.sort()
        self.infected_files = infected_paths
        self.gt_files = gt_paths

        # Validate the dataset: something to load, and the same amount on both sides.
        n_infected, n_gt = len(infected_paths), len(gt_paths)
        if n_infected == 0 or n_gt == 0:
            raise ValueError(f"No images found in {infected_dir} or {gt_dir}")
        if n_infected != n_gt:
            raise ValueError(f"Mismatch in number of images: {n_infected} in infected, {n_gt} in ground truth")

        # Check that every pair shares the same file name.
        first_mismatch = next(
            (
                (inf_path, gt_path)
                for inf_path, gt_path in zip(infected_paths, gt_paths)
                if os.path.basename(inf_path) != os.path.basename(gt_path)
            ),
            None,
        )
        if first_mismatch is not None:
            raise ValueError(f"File names do not match: {first_mismatch[0]} vs {first_mismatch[1]}")

        print(f"Loaded {len(self.infected_files)} image pairs")

    def __len__(self):
        """Number of image pairs."""
        return len(self.infected_files)

    def __getitem__(self, idx):
        """Return ``(infected, ground_truth, file_name)`` for pair ``idx``.

        Both images are opened as single-channel ('L') and passed through
        ``transform`` when one was given.
        """
        infected_path = self.infected_files[idx]
        gt_path = self.gt_files[idx]

        try:
            infected_img = Image.open(infected_path).convert('L')
            gt_img = Image.open(gt_path).convert('L')
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Error loading image: {infected_path} or {gt_path}") from e

        if self.transform:
            infected_img, gt_img = self.transform(infected_img), self.transform(gt_img)

        return infected_img, gt_img, os.path.basename(infected_path)

In [None]:
def save_images(infected_img, restored_img, gt_img, filename, output_dir):
    """Write one infected / restored / ground-truth triple as 8-bit PNG files.

    The files go to ``<output_dir>/samples`` (created if needed) and are named
    ``<stem>_infected.png``, ``<stem>_restored.png`` and
    ``<stem>_ground_truth.png``, where ``<stem>`` is ``filename`` without its
    extension.

    Args:
        infected_img, restored_img, gt_img: image tensors; values are assumed
            to be in [-1, 1] and are mapped to [0, 1] before saving.
        filename: name used to derive the output file stem.
        output_dir: directory under which the ``samples`` folder lives.
    """
    def to_unit_array(tensor):
        # Rescale to [0, 1], drop gradient tracking, then hand over to NumPy.
        return (tensor * 0.5 + 0.5).clamp(0, 1).detach().squeeze().cpu().numpy()

    # Convert everything first so nothing is written if a conversion fails.
    arrays_by_suffix = [
        ("_infected.png", to_unit_array(infected_img)),
        ("_restored.png", to_unit_array(restored_img)),
        ("_ground_truth.png", to_unit_array(gt_img)),
    ]

    sample_dir = os.path.join(output_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)

    base_name = os.path.splitext(filename)[0]
    for suffix, pixels in arrays_by_suffix:
        target_path = os.path.join(sample_dir, base_name + suffix)
        Image.fromarray((pixels * 255).astype(np.uint8)).save(target_path)

In [None]:
def train_model(generator, disc1, disc2, dataloader, num_epochs=300, save_samples=5):
    """Adversarially train ``generator`` against two patch discriminators.

    Every step updates both discriminators first (real = ground truth,
    fake = detached generator output) and then the generator, whose loss is
    ``1 * adversarial + 100 * L1 pixel + 10 * perceptual + 250 * style``.
    Tensors are moved to the module-level ``device``; checkpoints and sample
    images are written below the module-level ``output_dir``:

    * ``best_generator.pth`` whenever the epoch's mean generator loss improves,
    * ``final_generator.pth`` once training has finished,
    * sample PNGs (via ``save_images``) for the first ``save_samples`` batches
      of every epoch.

    Args:
        generator: network mapping infected images to restored images.
        disc1, disc2: discriminators returning probabilities in [0, 1].
        dataloader: yields ``(infected, ground_truth, file_names)`` batches.
        num_epochs: number of passes over ``dataloader``.
        save_samples: number of leading batches per epoch to save samples for.

    Returns:
        The trained ``generator`` (same object that was passed in).

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    # An empty loader would otherwise surface as a ZeroDivisionError when the
    # epoch averages are computed.
    if len(dataloader) == 0:
        raise ValueError("dataloader yields no batches; nothing to train on")

    loss_fn = LossFunctions()
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(list(disc1.parameters()) + list(disc2.parameters()), lr=0.0002, betas=(0.5, 0.999))

    lambda_adv, lambda_pixel, lambda_perc, lambda_style = 1.0, 100.0, 10.0, 250.0
    best_g_loss = float('inf')

    # Make sure BatchNorm layers use batch statistics even if a model was
    # switched to eval mode earlier (e.g. by an inference cell).
    generator.train()
    disc1.train()
    disc2.train()

    for epoch in range(num_epochs):
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = 0

        for i, (infected, gt, filename) in enumerate(dataloader):
            infected, gt = infected.to(device), gt.to(device)

            # A single generator pass serves both updates: the discriminators
            # see a detached copy, the generator step reuses the same graph
            # (generator weights do not change in between).
            fake = generator(infected)

            # Train the discriminators
            d_optimizer.zero_grad()

            real_d1 = disc1(gt)
            fake_d1 = disc1(fake.detach())
            real_d2 = disc2(gt)
            fake_d2 = disc2(fake.detach())

            d_loss = (loss_fn.adversarial_loss(real_d1, torch.ones_like(real_d1)) +
                      loss_fn.adversarial_loss(fake_d1, torch.zeros_like(fake_d1)) +
                      loss_fn.adversarial_loss(real_d2, torch.ones_like(real_d2)) +
                      loss_fn.adversarial_loss(fake_d2, torch.zeros_like(fake_d2))) / 4

            d_loss.backward()
            d_optimizer.step()

            # Train the generator (scored by the freshly updated discriminators)
            g_optimizer.zero_grad()
            fake_d1 = disc1(fake)
            fake_d2 = disc2(fake)

            g_adv_loss = (loss_fn.adversarial_loss(fake_d1, torch.ones_like(fake_d1)) +
                          loss_fn.adversarial_loss(fake_d2, torch.ones_like(fake_d2))) / 2
            g_pixel_loss = loss_fn.pixel_loss(fake, gt)
            g_perc_loss = loss_fn.perceptual_loss(fake, gt)
            g_style_loss = loss_fn.style_loss(fake, gt)

            g_loss = (lambda_adv * g_adv_loss + lambda_pixel * g_pixel_loss +
                      lambda_perc * g_perc_loss + lambda_style * g_style_loss)

            g_loss.backward()
            g_optimizer.step()

            total_g_loss += g_loss.item()
            total_d_loss += d_loss.item()
            num_batches += 1

            # Save sample images for the first few batches
            if i < save_samples:
                save_images(infected[0], fake[0], gt[0], filename[0], output_dir)

            if i % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(dataloader)}] "
                      f"D Loss: {d_loss.item():.4f} G Loss: {g_loss.item():.4f}")

        # Epoch-level mean losses
        avg_g_loss = total_g_loss / num_batches
        avg_d_loss = total_d_loss / num_batches
        print(f"Epoch [{epoch+1}/{num_epochs}] Avg D Loss: {avg_d_loss:.4f} Avg G Loss: {avg_g_loss:.4f}")

        # Keep the checkpoint with the lowest mean generator loss so far
        if avg_g_loss < best_g_loss:
            best_g_loss = avg_g_loss
            torch.save(generator.state_dict(), os.path.join(output_dir, "best_generator.pth"))
            print(f"Saved best model at epoch {epoch+1} with G Loss: {avg_g_loss:.4f}")

    # Save the final model
    torch.save(generator.state_dict(), os.path.join(output_dir, "final_generator.pth"))
    return generator

In [None]:
infected_dir = "/kaggle/input/tumor-injection-attack/ATTACKED"
gt_dir = "/kaggle/input/axial-mri-norm"

# Fail fast when either dataset directory is missing
if not os.path.exists(infected_dir) or not os.path.exists(gt_dir):
    raise FileNotFoundError(f"Dataset directories not found: {infected_dir} or {gt_dir}")

# Image preprocessing
transform = Compose([
    Resize((128, 128)),  # resize every image to 128x128
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])  # shift/scale ToTensor's [0, 1] output to [-1, 1]
])

# Build the dataset and dataloader. A ValueError raised by the dataset checks
# is left to propagate: it stops the cell (and "Run All") with the full
# message, which the interactive-only `exit()` helper does not reliably do
# inside a notebook kernel.
dataset = MedicalImageDataset(infected_dir, gt_dir, transform)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# Instantiate the models
generator = UNetGenerator(in_channels=1, out_channels=1).to(device)
disc1 = PatchDiscriminator(in_channels=1).to(device)
disc2 = PatchDiscriminator(in_channels=1).to(device)

# Train
generator = train_model(generator, disc1, disc2, dataloader, num_epochs=200, save_samples=5)

# Testing

In [None]:
def save_and_display_images(infected_img, restored_img, filename, output_dir):
    """Save an infected / restored image pair as PNG files and show them side by side.

    The files are written to ``output_dir`` (created if needed) as
    ``<stem>_infected.png`` and ``<stem>_restored.png``, where ``<stem>`` is
    the base name of ``filename`` without its extension.

    Args:
        infected_img, restored_img: image tensors; values are assumed to be in
            [-1, 1] and are mapped to [0, 1] before saving. Tensors that still
            track gradients are accepted.
        filename: path or name used to derive the output file stem.
        output_dir: directory the PNG files are written to.
    """
    # Denormalize and convert to NumPy; detach first so tensors produced
    # outside torch.no_grad() can be converted as well.
    infected_img = (infected_img * 0.5 + 0.5).clamp(0, 1).detach().squeeze().cpu().numpy()
    restored_img = (restored_img * 0.5 + 0.5).clamp(0, 1).detach().squeeze().cpu().numpy()

    # Create the output directory
    os.makedirs(output_dir, exist_ok=True)

    # Save the images
    base_name = os.path.splitext(os.path.basename(filename))[0]
    Image.fromarray((infected_img * 255).astype(np.uint8)).save(
        os.path.join(output_dir, f"{base_name}_infected.png"))
    Image.fromarray((restored_img * 255).astype(np.uint8)).save(
        os.path.join(output_dir, f"{base_name}_restored.png"))
    print(f"Images saved to {output_dir}: {base_name}_infected.png, {base_name}_restored.png")

    # Display the images
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(infected_img, cmap='gray')
    axes[0].set_title("Infected Image")
    axes[0].axis('off')
    axes[1].imshow(restored_img, cmap='gray')
    axes[1].set_title("Restored Image")
    axes[1].axis('off')
    plt.show()

In [None]:
weights_path = "/kaggle/working/output/best_generator.pth"
infected_image_path = "/kaggle/input/tumor-injection-attack/ATTACKED/sub-BrainAge000129_T1w_axial_center.png"  # replace with a real image path
# Deliberately NOT named `output_dir`: train_model reads the module-level
# `output_dir`, so rebinding it here would redirect checkpoints and samples
# into the "restored" folder if the training cell were run again afterwards.
restored_dir = "/kaggle/working/output/restored"

if not os.path.exists(weights_path):
    raise FileNotFoundError(f"Weights file not found: {weights_path}")
if not os.path.exists(infected_image_path):
    raise FileNotFoundError(f"Image file not found: {infected_image_path}")

# Same preprocessing as used for training
transform = Compose([
    Resize((128, 128)),
    ToTensor(),
    Normalize(mean=[0.5], std=[0.5])
])

# Load and preprocess the image (add a batch dimension)
infected_img = Image.open(infected_image_path).convert('L')
infected_img = transform(infected_img).unsqueeze(0).to(device)

# Build the model and load the trained weights
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
generator = UNetGenerator(in_channels=1, out_channels=1).to(device)
generator.load_state_dict(torch.load(weights_path, map_location=device))
generator.eval()

# Restore the image
with torch.no_grad():
    restored_img = generator(infected_img)

save_and_display_images(infected_img, restored_img, infected_image_path, restored_dir)