In [None]:
# Copy directory `1` of the sitta_official_replicate input into /kaggle/working/
# so that later cells can cd into it and rewrite its config file.
!cp -r /kaggle/input/sitta_official_replicate/pytorch/default/1  /kaggle/working/

In [None]:
# Make the copied SITTA-Vidya project directory the working directory.
%cd /kaggle/working/1/SITTA-Vidya

In [None]:
# List the current directory to check what the copy produced.
!ls -l

In [None]:
# %pip installs into the environment of the running kernel; `!pip` runs whatever
# `pip` is first on the shell PATH, which may belong to a different interpreter.
%pip install torch torchvision numpy pyyaml opencv-python matplotlib tqdm

In [None]:
# Metric dependencies (FID and LPIPS); %pip targets the running kernel's environment.
%pip install torch-fidelity lpips

# Step 2: Dataset Preparation

In [None]:
import os
import shutil
import random
import glob

# Define dataset paths
dataset_path = "/kaggle/input/plant-pathology-sitta/images"
output_dir = "/kaggle/working/dataset"

# Dedicated, seeded RNG so Restart & Run All always produces the same split
# (and re-running the cell copies the same files instead of mixing in new ones).
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)

# Create required directories
for split_name in ("trainA", "testA", "trainB", "testB"):
    os.makedirs(f"{output_dir}/leaves/{split_name}", exist_ok=True)

# Get all image files. glob returns them in filesystem order, so sort first to
# make the seeded shuffle deterministic across machines.
all_images = sorted(glob.glob(f"{dataset_path}/*.jpg"))
if not all_images:
    # Fail loudly instead of reporting success on an empty dataset.
    raise FileNotFoundError(f"No .jpg images found in {dataset_path}")
split_rng.shuffle(all_images)

# Split into 80% train, 20% test
split_idx = int(0.8 * len(all_images))
trainA_images = all_images[:split_idx]
testA_images = all_images[split_idx:]

# Copy images to trainA/testA
for img in trainA_images:
    shutil.copy(img, f"{output_dir}/leaves/trainA/")

for img in testA_images:
    shutil.copy(img, f"{output_dir}/leaves/testA/")

# Select up to 200 random trainA images to act as "textures" (domain B)
num_textures = min(200, len(trainA_images))
texture_images = split_rng.sample(trainA_images, num_textures)

# Hold out the first 10% of the textures as testB; only the remainder goes to
# trainB, so the two B splits do not share any image.
testB_size = int(0.1 * len(texture_images))
testB_images = texture_images[:testB_size]
trainB_images = texture_images[testB_size:]

for img in trainB_images:
    shutil.copy(img, f"{output_dir}/leaves/trainB/")

for img in testB_images:
    shutil.copy(img, f"{output_dir}/leaves/testB/")

print("✅ Dataset structured successfully!")

# Step 3: Resize Images

In [None]:
from PIL import Image
import os
import glob

# Splits to process and the directory that contains them
dataset_dirs = ["trainA", "testA", "trainB", "testB"]
base_dir = "/kaggle/working/dataset/leaves"

# Size passed to PIL's resize for every image
TARGET_SIZE = (288, 288)

for split in dataset_dirs:
    src_dir = os.path.join(base_dir, split)
    dst_dir = os.path.join(base_dir, f"{split}_resized")
    os.makedirs(dst_dir, exist_ok=True)

    # Write an RGB, bicubic-resized copy of each .jpg under the same file name
    for src_path in glob.glob(src_dir + "/*.jpg"):
        dst_path = os.path.join(dst_dir, os.path.basename(src_path))
        rgb_image = Image.open(src_path).convert("RGB")
        rgb_image.resize(TARGET_SIZE, Image.BICUBIC).save(dst_path)

    print(f"✅ Resized images saved in {dst_dir}")

print("🔥 All images resized successfully!")

# Step 4: Update YAML

In [None]:
import yaml

# Config file of the copied project that is rewritten in place
yaml_path = "/kaggle/working/1/SITTA-Vidya/configs/single2single.yaml"

# Top-level keys to overwrite so the config points at the resized folders
dir_overrides = {
    "trainA_dir": "/kaggle/working/dataset/leaves/trainA_resized",
    "testA_dir": "/kaggle/working/dataset/leaves/testA_resized",
    "trainB_dir": "/kaggle/working/dataset/leaves/trainB_resized",
    "testB_dir": "/kaggle/working/dataset/leaves/testB_resized",
}

with open(yaml_path, "r") as cfg_in:
    config = yaml.safe_load(cfg_in)

config.update(dir_overrides)

# Write the modified mapping back over the original file
with open(yaml_path, "w") as cfg_out:
    yaml.dump(config, cfg_out)

print("✅ YAML updated successfully!")

# Step 5: SITTA Model Implementation (With PONO)
This includes:

- PONO Normalization Layer
- SITTA Generator (with and without PONO)
- SITTA Discriminator

## 5.1: Define PONO Normalization

In [None]:
import torch
import torch.nn as nn

# Define PONO Normalization Layer
class PONO(nn.Module):
    """Positional Normalization (PONO).

    Normalizes every spatial position of a feature map with the mean and
    standard deviation computed across the channel dimension at that position,
    and returns those moments together with the normalized features.

    Args:
        eps: Constant added to the variance before the square root. It keeps
            the division finite and the gradient defined at positions whose
            channels are all equal (e.g. all zero after a ReLU).
    """

    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` across channels.

        Args:
            x: Tensor indexed as (batch, channel, height, width).

        Returns:
            Tuple ``(normalized, mean, std)``. ``normalized`` has the shape of
            ``x``; ``mean`` and ``std`` keep the channel axis with size 1, i.e.
            one value per (batch, height, width) position.
        """
        # Moments over dim=1 (channels), not over the spatial dims: reducing
        # over (2, 3) would be instance normalization instead of PONO.
        mean = x.mean(dim=1, keepdim=True)
        # sqrt(var + eps) rather than std + eps: std has no finite gradient
        # where the variance is exactly zero.
        std = (x.var(dim=1, keepdim=True) + self.eps).sqrt()
        return (x - mean) / std, mean, std

##  5.2: Define SITTA Generator (With & Without PONO)

In [None]:
# Define SITTA Generator with option to enable/disable PONO
class SITTA_Generator(nn.Module):
    """Convolutional encoder-decoder generator with optional PONO on the bottleneck.

    The encoder maps 3 input channels to 256 feature channels through one
    7x7 convolution and two stride-2 3x3 convolutions. The decoder mirrors it
    with two stride-2 transposed convolutions and a final 7x7 convolution
    followed by Tanh, so outputs lie in [-1, 1].

    Args:
        use_pono: When True, the encoder output is passed through a ``PONO``
            layer before decoding; the moments PONO returns are discarded.
    """

    def __init__(self, use_pono=False):
        super().__init__()
        self.use_pono = use_pono
        self.pono_layer = PONO() if use_pono else None

        # Encoder: 3 -> 64 -> 128 -> 256 channels, two stride-2 steps
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=7, padding=3),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: 256 -> 128 -> 64 -> 3 channels, two stride-2 transposed steps
        decoder_layers = [
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=7, padding=3),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x``, optionally normalize the features with PONO, then decode."""
        latent = self.encoder(x)

        if not self.use_pono:
            return self.decoder(latent)

        # Only the normalized features are used; mean and std are dropped
        normalized, _, _ = self.pono_layer(latent)
        return self.decoder(normalized)

## 5.3: Define SITTA Discriminator

In [None]:
# Define SITTA Discriminator
class SITTA_Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a one-channel score map.

    Three stride-2 4x4 convolutions (3 -> 64 -> 128 -> 256 channels), each
    followed by LeakyReLU(0.2), and a final stride-1 4x4 convolution down to a
    single channel. No activation is applied to the output.
    """

    def __init__(self):
        super().__init__()

        layers = []
        in_channels = 3
        for out_channels in (64, 128, 256):
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))
            in_channels = out_channels
        # Final projection to a single-channel score map
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

## 5.4: Initialize Models

In [None]:
# Initialize models on the GPU when one is available, otherwise on the CPU,
# so this cell does not crash on a CPU-only kernel.
# `device` is kept as a notebook-level name so that other cells can place
# their tensors and models on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator_without_pono = SITTA_Generator(use_pono=False).to(device)
generator_with_pono = SITTA_Generator(use_pono=True).to(device)
discriminator = SITTA_Discriminator().to(device)

print("✅ Generator (with & without PONO) and Discriminator initialized successfully!")

# Step 6: Training Function for SITTA (With & Without PONO)
✔ Includes:


* Adversarial Loss
* Cycle Consistency Loss
* Identity Loss
* Training Function
* Training Execution for both With & Without PONO

## 6.1: Define Loss Functions

In [None]:
import torch.optim as optim
import torch.nn.functional as F

# Loss terms of the generator objective
adversarial_loss = nn.MSELoss()  # GAN term (mean squared error against the target labels)
cycle_loss = nn.L1Loss()         # cycle-consistency term
identity_loss = nn.L1Loss()      # identity term

# Adam hyper-parameters shared by both optimizers
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))

# NOTE: optimizer_G holds the parameters of generator_with_pono only;
# generator_without_pono is not covered by it.
optimizer_G = optim.Adam(generator_with_pono.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

print("✅ Loss functions and optimizers initialized!")

## 6.2: Define Training Function

In [None]:
import time

def train_sitta(generator, label, epochs=5, optimizer=None):
    """Train one SITTA generator against the notebook-level discriminator.

    Every step minimises ``adv + 10 * cycle + 5 * identity`` where
    ``adv`` is ``adversarial_loss`` of the discriminator's output on the
    generated batch against an all-ones target, ``cycle`` compares
    ``generator(generator(real_A))`` with ``real_A`` and ``identity`` compares
    ``generator(real_B)`` with ``real_B``.

    Reads these notebook-level names: ``trainA_loader``, ``trainB_loader``,
    ``discriminator``, ``adversarial_loss``, ``cycle_loss``, ``identity_loss``.
    Both loaders are expected to yield image tensors directly, and the
    discriminator is expected to live on the same device as ``generator``.

    Args:
        generator: Module to train; its parameters are the ones that get updated.
        label: Name used in the progress messages.
        epochs: Number of passes over ``trainA_loader``.
        optimizer: Optimizer for ``generator``. When omitted, an Adam optimizer
            (lr=0.0002, betas=(0.5, 0.999)) is created for this generator, so
            each generator passed in is stepped by an optimizer that actually
            holds its own parameters.
    """
    print(f"🚀 Training {label} Generator...")

    if optimizer is None:
        optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    # Follow the generator's device instead of assuming CUDA
    device = next(generator.parameters()).device

    # NOTE(review): only the generator is optimized here; the discriminator's
    # weights are never updated by this function, so the adversarial term is
    # measured against a fixed discriminator.
    for epoch in range(epochs):
        start_time = time.time()
        gen_loss = None  # stays None when a loader yields no batches

        for real_A in trainA_loader:
            real_A = real_A.to(device)
            # Every A batch is paired with every B batch
            for real_B in trainB_loader:
                real_B = real_B.to(device)
                fake_B = generator(real_A)

                # Compute loss (single discriminator pass, reused for the target shape)
                pred_fake = discriminator(fake_B)
                adv_loss = adversarial_loss(pred_fake, torch.ones_like(pred_fake))
                cyc_loss = cycle_loss(generator(fake_B), real_A)
                idt_loss = identity_loss(generator(real_B), real_B)
                gen_loss = adv_loss + 10 * cyc_loss + 5 * idt_loss

                # Optimize Generator
                optimizer.zero_grad()
                gen_loss.backward()
                optimizer.step()

        end_time = time.time() - start_time
        if gen_loss is None:
            print(f"Epoch [{epoch+1}/{epochs}] - no batches to train on - Time: {end_time:.2f}s")
        else:
            # Reported loss is that of the last step of the epoch
            print(f"Epoch [{epoch+1}/{epochs}] - Gen Loss: {gen_loss.item():.4f} - Time: {end_time:.2f}s")

    print(f"✅ Training Complete for {label}!")

## 6.3: Train Both Models (With & Without PONO)

In [None]:
# Train both variants with the default settings: baseline first, then PONO
training_runs = [
    (generator_without_pono, "Without PONO"),
    (generator_with_pono, "With PONO"),
]

for run_generator, run_label in training_runs:
    train_sitta(run_generator, run_label)

print("✅ Training completed for both versions!")

# Step 7: Evaluation Metrics for SITTA
✔ Includes:

* FID (Fréchet Inception Distance)
* LPIPS (Learned Perceptual Image Patch Similarity)
* VGG Loss
* Visualization of Results

## 7.1: Define FID Calculation

In [None]:
from torch_fidelity import calculate_metrics

class _TensorImageDataset(torch.utils.data.Dataset):
    """Expose a batch of images as a dataset of uint8 image tensors.

    torch-fidelity takes its inputs as a path, a registered input name or a
    ``Dataset`` of uint8 images, not as raw float batches, so tensors are
    converted and wrapped here.
    """

    def __init__(self, images):
        images = images.detach().float().cpu()
        # Heuristic: negative values mean the batch is scaled to [-1, 1]
        # (e.g. the output of a Tanh); otherwise it is assumed to be in [0, 1].
        # NOTE(review): confirm the value range the data loaders produce.
        if images.numel() > 0 and images.min() < 0:
            images = (images + 1.0) / 2.0
        self.images = (images.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, index):
        return self.images[index]


def _as_fid_input(images):
    """Wrap tensors in a dataset; pass any other input through unchanged."""
    if torch.is_tensor(images):
        return _TensorImageDataset(images)
    return images


def compute_fid(real_images, generated_images):
    """
    Compute FID between real and generated images.

    Args:
        real_images: Image batch tensor, or any input torch-fidelity accepts
            directly (directory path, registered input name, Dataset).
        generated_images: Same kinds of input as ``real_images``.

    Returns:
        The ``frechet_inception_distance`` entry of the metrics dictionary.

    Note:
        FID is estimated from feature statistics of each set, so a score
        computed on a handful of images is not a reliable estimate.
    """
    metrics = calculate_metrics(
        input1=_as_fid_input(real_images),
        input2=_as_fid_input(generated_images),
        fid=True,
        cuda=torch.cuda.is_available()
    )
    return metrics["frechet_inception_distance"]

##  7.2: Define LPIPS Calculation

In [None]:
import lpips

# LPIPS with the AlexNet backbone. The device is chosen here so that this cell
# does not depend on a notebook-level `device` variable being defined.
lpips_model = lpips.LPIPS(net='alex').to("cuda" if torch.cuda.is_available() else "cpu")

def compute_lpips(real_images, generated_images):
    """
    Compute LPIPS score.

    Args:
        real_images: Batch of reference images.
        generated_images: Batch of images to compare, matching ``real_images``.

    Returns:
        Mean LPIPS distance over the batch as a Python float.
    """
    # NOTE(review): LPIPS is documented to expect inputs scaled to [-1, 1]
    # unless normalize=True is passed - confirm the range of real_images.
    model_device = next(lpips_model.parameters()).device
    # Metric evaluation only: no autograd graph is needed
    with torch.no_grad():
        distances = lpips_model(
            real_images.to(model_device),
            generated_images.to(model_device),
        )
    return distances.mean().item()

## 7.3: Define VGG Loss

In [None]:
import torchvision.models as models

class VGGLoss(nn.Module):
    """L1 distance between VGG-19 feature maps of two image batches.

    Uses the first 16 modules of the pretrained VGG-19 ``features`` stack,
    switched to eval mode and with gradients disabled for all of its
    parameters.
    """

    def __init__(self):
        super().__init__()
        backbone = models.vgg19(pretrained=True)
        self.vgg = backbone.features[:16].eval()
        # Freeze the feature extractor
        self.vgg.requires_grad_(False)
        self.criterion = nn.L1Loss()

    def forward(self, generated, target):
        """Return the L1 loss between the features of ``generated`` and ``target``."""
        feats_generated = self.vgg(generated)
        # Target features are detached: no gradient flows into ``target``
        feats_target = self.vgg(target).detach()
        loss = self.criterion(feats_generated, feats_target)
        return loss

##  7.4: Evaluate Trained Models

In [None]:
# Evaluation inputs: the first (up to) 8 images of the first trainA_loader batch
first_batch = next(iter(trainA_loader)).cuda()
test_real_A = first_batch[:8]

# Textures from the PONO generator, detached from the autograd graph
generated_batch = generator_with_pono(test_real_A)
test_fake_B = generated_batch.detach()

# Score the generated batch against the real batch
fid_score = compute_fid(test_real_A, test_fake_B)
lpips_score = compute_lpips(test_real_A, test_fake_B)

# Report both metrics
print(f"📊 FID Score: {fid_score:.2f}")
print(f"📊 LPIPS Score: {lpips_score:.3f}")

## 7.5: Visualizing Results

In [None]:
import matplotlib.pyplot as plt

def _to_display_array(images):
    """Convert a (batch, channel, height, width) tensor to a channels-last array for imshow."""
    # detach() so tensors that still require grad can be converted to numpy
    array = images.detach().permute(0, 2, 3, 1).cpu().numpy()
    if array.dtype.kind == "f":
        # imshow expects float images in [0, 1]. Negative values are taken to
        # mean a [-1, 1] scale (e.g. Tanh output) and are mapped to [0, 1].
        if array.size and array.min() < 0:
            array = (array + 1.0) / 2.0
        array = array.clip(0.0, 1.0)
    return array


def visualize_results(real_A, fake_B):
    """
    Displays a side-by-side comparison of real images and their generated textures.

    Args:
        real_A: Batch of input images, indexed (batch, channel, height, width).
        fake_B: Batch of generated images with the same layout and batch size.

    The top row shows ``real_A`` and the bottom row shows ``fake_B``.
    """
    real_A = _to_display_array(real_A)
    fake_B = _to_display_array(fake_B)

    # squeeze=False keeps `axes` two-dimensional even for a single image,
    # so axes[row, col] indexing always works.
    fig, axes = plt.subplots(2, len(real_A), figsize=(12, 5), squeeze=False)

    for i in range(len(real_A)):
        axes[0, i].imshow(real_A[i])
        axes[0, i].axis("off")
        axes[1, i].imshow(fake_B[i])
        axes[1, i].axis("off")

    axes[0, 0].set_title("Original Shape")
    axes[1, 0].set_title("Generated Texture")
    plt.show()

# Display a batch of results
visualize_results(test_real_A, test_fake_B)