In [None]:
!pip install timm
!pip install comet_ml

In [None]:
import os

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F


# --------------------------------------------
# Residual Block with optional upsampling
# --------------------------------------------
class ResidualBlockTranspose(nn.Module):
    """Two-convolution residual block that can optionally double the spatial size.

    Main path: conv1 -> BatchNorm -> ReLU -> conv2 -> BatchNorm. The result is
    added to a shortcut branch and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        upsample: When True, ``conv1`` and the shortcut are stride-2 transposed
            convolutions, so height and width are doubled. When False the
            spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample

        # First layer of the main path: the only place where resolution changes.
        if upsample:
            self.conv1 = nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Second layer of the main path keeps both resolution and channel count.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The shortcut must produce the same shape as the main path, so it needs
        # a 1x1 projection whenever resolution or channel count changes.
        if upsample:
            projection = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size=1, stride=2, output_padding=1
            )
        elif in_channels != out_channels:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            projection = None

        if projection is None:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.shortcut(x)

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        return self.relu(y + residual)


# --------------------------------------------
# Encoder using timm (ResNet)
# --------------------------------------------
class ResNetEncoder(nn.Module):
    """Feature extractor built from a timm backbone in ``features_only`` mode.

    Args:
        model_name: timm model identifier passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.

    Attributes:
        encoder: The timm backbone; calling it yields a sequence of feature maps.
        out_channels: ``num_chs`` of the last entry in the backbone's
            ``feature_info``, i.e. the channel count of the map ``forward`` returns.
    """

    def __init__(self, model_name="resnet18", pretrained=True):
        super().__init__()
        backbone = timm.create_model(model_name, pretrained=pretrained, features_only=True)
        self.encoder = backbone

        deepest_stage = backbone.feature_info[-1]
        self.out_channels = deepest_stage['num_chs']

    def forward(self, x):
        """Run the backbone and keep only its final feature map."""
        feature_maps = self.encoder(x)
        # e.g. shape (B, 512, 16, 16) for a 512×512 input
        return feature_maps[-1]


# --------------------------------------------
# Decoder: Upsample back to 512×512
# --------------------------------------------
class ResNetDecoder(nn.Module):
    """Decoder that upsamples a feature map by 32× and maps it to image channels.

    Five upsampling residual blocks each double the spatial size
    (e.g. 16 → 32 → 64 → 128 → 256 → 512) while narrowing the channels
    ``in_channels → 256 → 128 → 64 → 32 → 16``. A final 3×3 convolution
    produces ``out_channels`` maps and a sigmoid squashes them into [0, 1].

    Args:
        in_channels: Channel count of the feature map fed to the decoder.
        out_channels: Channel count of the reconstructed image.
    """

    def __init__(self, in_channels, out_channels=3):
        super().__init__()
        # Channel width at the boundary of each upsampling stage.
        widths = [in_channels, 256, 128, 64, 32, 16]
        stages = [
            ResidualBlockTranspose(c_in, c_out, upsample=True)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        head = nn.Conv2d(widths[-1], out_channels, kernel_size=3, padding=1)
        # Sigmoid: targets are assumed to lie in [0, 1].
        self.decoder = nn.Sequential(*stages, head, nn.Sigmoid())

    def forward(self, x):
        """Decode feature map ``x`` into an image-shaped tensor."""
        reconstruction = self.decoder(x)
        return reconstruction


# --------------------------------------------
# Autoencoder model
# --------------------------------------------
class ResNetAutoencoder(nn.Module):
    """Autoencoder pairing a ``ResNetEncoder`` with a matching ``ResNetDecoder``.

    Args:
        model_name: Backbone identifier forwarded to ``ResNetEncoder``.
    """

    def __init__(self, model_name="resnet18"):
        super().__init__()
        self.encoder = ResNetEncoder(model_name)
        # The decoder input width has to match what the encoder emits.
        latent_channels = self.encoder.out_channels
        self.decoder = ResNetDecoder(latent_channels)

    def forward(self, x):
        """Encode ``x`` and return the decoder's reconstruction."""
        return self.decoder(self.encoder(x))




In [None]:
from torch.utils.data import Dataset
from PIL import Image
class ImageDataset(Dataset):
    """Dataset over the image files that sit directly inside one folder.

    Files are selected by (case-insensitive) extension and stored in sorted
    order, so a given index always refers to the same file regardless of the
    order in which the operating system lists the directory.

    Each item is a ``(image, "")`` tuple; the empty string is a placeholder
    label so the dataset can be iterated as ``for images, _ in loader``.
    """

    # Extensions (lower-case) that are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir (str): Path to folder containing images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        self.image_paths = sorted(
            os.path.join(root_dir, fname)
            for fname in os.listdir(root_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return ``(image, "")``."""
        image_path = self.image_paths[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied by convert(), instead of waiting for GC.
        with Image.open(image_path) as source:
            image = source.convert("RGB")  # Ensure 3 channels

        if self.transform is not None:
            image = self.transform(image)

        return image, ""

In [None]:
def show_reconstructions(model, dataloader, device, epoch, output_dir="/kaggle/working"):
    """Plot originals vs. reconstructions for the first batch and save the figure.

    The first batch of ``dataloader`` is run through ``model`` in eval mode with
    gradients disabled. Up to six samples are drawn in a 2-row grid (originals
    on top, reconstructions below) and written to
    ``<output_dir>/reconstructed_image_<epoch>.png``.

    Args:
        model: Module mapping an image batch to a reconstruction of the same
            layout (channels-first; transposed to HWC for plotting).
        dataloader: Iterable yielding either ``(images, label)`` pairs or bare
            image batches.
        device: Device the batch is moved to before the forward pass.
        epoch: Value embedded in the output file name.
        output_dir: Directory the PNG is written to; created if missing.

    Returns:
        str: Path of the saved PNG.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    # Remember the current mode so visualisation does not silently leave the
    # model in eval mode for whoever called us.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            batch = next(iter(dataloader), None)
            if batch is None:
                raise ValueError("dataloader yielded no batches; nothing to reconstruct")
            images = batch[0] if isinstance(batch, (tuple, list)) else batch
            images = images.to(device)
            outputs = model(images)
    finally:
        model.train(was_training)

    images = images.cpu().numpy()
    outputs = outputs.cpu().numpy()

    n = min(6, images.shape[0])
    # NOTE: the figure is left open so notebook/interactive backends can still
    # display it; call plt.close(fig) at the call site if that is not wanted.
    fig = plt.figure(figsize=(12, 4))
    for i in range(n):
        top = fig.add_subplot(2, n, i + 1)
        top.imshow(images[i].transpose(1, 2, 0))  # CHW -> HWC
        top.set_title("Original")
        top.axis("off")

        bottom = fig.add_subplot(2, n, i + n + 1)
        bottom.imshow(outputs[i].transpose(1, 2, 0))
        bottom.set_title("Reconstructed")
        bottom.axis("off")

    os.makedirs(output_dir, exist_ok=True)
    image_path = os.path.join(output_dir, f"reconstructed_image_{epoch}.png")
    fig.savefig(image_path)
    return image_path
    

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import os  # Assuming your model code is in model.py
import matplotlib.pyplot as plt
from comet_ml import start
from torch.utils.data import random_split
# -----------------------------
# Configuration
# -----------------------------
# The Comet API key is a secret and must not be stored in the notebook source.
# Supply it through the COMET_API_KEY environment variable (or `comet login`);
# when the variable is unset, None is passed and comet_ml falls back to its own
# configuration lookup.
# NOTE: the key that was previously hard-coded here should be treated as
# compromised and rotated.
experiment = start(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="diabetic-retinopathy_unsupervised",
    workspace="neloy-sarwar",
)

# Flat folder of image files; ImageDataset reads the files directly inside it.
data_dir = "/kaggle/input/diabetic-retinopathy-test-resized-512/test_images_resized_512"
batch_size = 32
num_epochs = 20
lr = 3e-4
seed = 42  # fixes the train/validation partition across runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
save_dir = "/kaggle/working/checkpoints"
os.makedirs(save_dir, exist_ok=True)

# -----------------------------
# Data transforms and loader
# -----------------------------
# No Resize step: the images are presumably already 512×512 (per the dataset
# name) — TODO confirm. ToTensor scales pixel values to [0, 1], matching the
# decoder's sigmoid output range.
transform = transforms.Compose([
    transforms.ToTensor(),
])

dataset = ImageDataset(root_dir=data_dir, transform=transform)
if len(dataset) == 0:
    # Fail early with a clear message instead of a division by zero later on.
    raise RuntimeError(f"No images found in {data_dir}")

# Define split ratio
val_split = 0.2  # 20% for validation
val_size = int(len(dataset) * val_split)
train_size = len(dataset) - val_size

# Split dataset with a dedicated, seeded generator so the validation set (and
# therefore the "best" checkpoint selection) is comparable between runs.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(seed),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
# -----------------------------
# Initialize model, loss, optimizer
# -----------------------------
model = ResNetAutoencoder("resnet18").to(device)
criterion = nn.MSELoss()  # pixel-wise reconstruction error
optimizer = optim.Adam(model.parameters(), lr=lr)

# -----------------------------
# Training loop
# -----------------------------
# Each epoch: train on train_loader, evaluate on val_loader, checkpoint the
# model whenever validation loss reaches a new minimum, and stop early once
# validation loss has not improved for `patience` consecutive epochs.
best_loss = float('inf')
epochs_no_improve = 0
patience = 5  # stop after 5 non-improving epochs
early_stop = False  # set to True if the loop ended through early stopping

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for images, _ in train_loader:
        images = images.to(device)

        # The autoencoder is trained to reproduce its own input.
        outputs = model(images)
        loss = criterion(outputs, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(f"Epoch [{epoch+1}/{num_epochs}] Batch loss:{loss.item()}")
        # Weight by batch size so the epoch average is per-sample even when
        # the last batch is smaller than the others.
        running_loss += loss.item() * images.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")
    experiment.log_metric("epoch_train_loss", epoch_loss, step=epoch)

    # Evaluate on validation set
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            outputs = model(images)
            loss = criterion(outputs, images)
            val_loss += loss.item() * images.size(0)

    val_loss /= len(val_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Loss: {val_loss:.4f}")
    experiment.log_metric("epoch_val_loss", val_loss, step=epoch)

    # Check for improvement
    if val_loss < best_loss:
        best_loss = val_loss
        epochs_no_improve = 0
        # Store optimizer state as well so training can be resumed from here.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss
        }, os.path.join(save_dir, "best_autoencoder.pt"))
        print(f"✅ Best model saved at epoch {epoch+1} with val loss {best_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"⚠️ No improvement for {epochs_no_improve} epoch(s).")

    # Show reconstructions every 5 epochs
    if (epoch + 1) % 5 == 0:
        # Visualisation is best-effort: a plotting or upload failure must not
        # abort training, but it should be visible rather than silently dropped.
        # KeyboardInterrupt is deliberately not caught.
        try:
            image_path = show_reconstructions(model, val_loader, device, epoch)
            experiment.log_image(image_path)
        except Exception as exc:
            print(f"⚠️ Could not log reconstructions for epoch {epoch+1}: {exc!r}")

    # Early stopping condition
    if epochs_no_improve >= patience:
        print(f"⏹️ Early stopping triggered at epoch {epoch+1}")
        early_stop = True
        break
