In [42]:
import torch
import torchvision.transforms as transforms
from torchvision.models import resnet18
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


import torchvision
import torchvision.transforms as transforms

mnist_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(3),
    transforms.ToTensor(),
])

mnist_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
mnist_testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)

svhn_transform = transforms.Compose([transforms.ToTensor(), ])
svhn_trainset = torchvision.datasets.SVHN(root='./data', split='train', download=True, transform=svhn_transform)
svhn_testset = torchvision.datasets.SVHN(root='./data', split='test', download=True, transform=svhn_transform)

# # Load pre-trained ResNet and modify for feature extraction
# resnet = resnet18(pretrained=False, num_classes=10)
# resnet.load_state_dict(torch.load("path_to_mnist_resnet.pth"))
# resnet.eval()

# # Feature extraction function
# def extract_features(image_batch):
#     with torch.no_grad():
#         features = resnet(image_batch)
#     return features


Using downloaded and verified file: ./data\train_32x32.mat
Using downloaded and verified file: ./data\test_32x32.mat


In [58]:
from diffusers import UNet2DConditionModel, DDPMScheduler

# Define UNet for diffusion
unet = UNet2DConditionModel(
    sample_size=32,  # SVHN image size
    in_channels=3,  # RGB for SVHN
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(64, 128, 256),
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("UpBlock2D", "UpBlock2D", "AttnUpBlock2D"),
    cross_attention_dim=512  # Conditioning vector size
).to(device)

# Define noise scheduler
scheduler = DDPMScheduler(num_train_timesteps=100)


In [46]:
class PairedMNISTSVHN(Dataset):
    def __init__(self, mnist_dataset, svhn_dataset):
        self.mnist_dataset = mnist_dataset
        self.svhn_dataset = svhn_dataset

        # Create a mapping from digit labels to indices for SVHN
        self.svhn_label_to_indices = {i: [] for i in range(10)}
        for idx, label in enumerate(self.svhn_dataset.labels):
            self.svhn_label_to_indices[label].append(idx)

    def __len__(self):
        # Return the size of the MNIST dataset
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        # Get an MNIST sample
        mnist_img, mnist_label = self.mnist_dataset[idx]
        
        # Get a random SVHN sample with the same label
        svhn_idx = torch.randint(0, len(self.svhn_label_to_indices[mnist_label]), (1,)).item()
        svhn_img = self.svhn_dataset[self.svhn_label_to_indices[mnist_label][svhn_idx]][0]

        return mnist_img, svhn_img, mnist_label

In [50]:
img1.shape

torch.Size([32, 3, 224, 224])

In [47]:
# Paired Dataset
paired_trainset = PairedMNISTSVHN(mnist_trainset, svhn_trainset)
paired_trainloader = DataLoader(paired_trainset, batch_size=32, shuffle=True)

# Paired Dataset
paired_testset = PairedMNISTSVHN(mnist_testset, svhn_testset)
paired_testloader = DataLoader(paired_testset, batch_size=32, shuffle=False)

In [56]:
import torch
from torchvision.models import resnet18

class ResNetFeatureExtractor(torch.nn.Module):
    def __init__(self, original_resnet):
        super(ResNetFeatureExtractor, self).__init__()
        # Retain all layers except the final FC layer
        self.features = torch.nn.Sequential(*list(original_resnet.children())[:-1])
    
    def forward(self, x):
        x = self.features(x)
        return x

# Load the ResNet model and modify it
resnet = resnet18(pretrained=False)
num_ftrs = resnet.fc.in_features
resnet.fc = torch.nn.Linear(num_ftrs, 10) # MNIST has 10 classes

for param in resnet.parameters():
    param.requires_grad = False

resnet.load_state_dict(torch.load("resnet_model.pth"))
feature_extractor = ResNetFeatureExtractor(resnet)

for param in feature_extractor.parameters():
    param.requires_grad = False

feature_extractor = feature_extractor.to(device)



# # Pass an input through the modified model
# input_image = torch.randn(1, 3, 224, 224)  # Example input
# features = feature_extractor(input_image)

# print("Extracted Features Shape:", features.shape)  # [batch_size, 512, 1, 1]

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import SVHN
from torchvision.transforms import Compose, ToTensor, Normalize

# Define optimizer and loss function
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

num_epochs = 5
# Training loop
for epoch in range(num_epochs):
    for batch in tqdm(paired_trainloader):
        mnist_images, svhn_images, labels = batch
        mnist_images = mnist_images.to(device)
        svhn_images = svhn_images.to(device)
        labels = labels.to(device)

        # Extract features from MNIST ResNet
        mnist_features = feature_extractor(mnist_images).to(device)

        # Generate noise and corrupted images
        noise = torch.randn_like(svhn_images).to(device)
        timesteps = torch.randint(0, scheduler.num_train_timesteps, (svhn_images.size(0),), device=device).long()
        noisy_images = scheduler.add_noise(svhn_images, noise, timesteps)

        # Predict noise with UNet conditioned on MNIST features
        predicted_noise = unet(noisy_images, timesteps, encoder_hidden_states=mnist_features.squeeze().unsqueeze(1)).sample

        # Calculate loss and backpropagate
        loss = criterion(predicted_noise, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        torch.save(unet.state_dict(), f"models/unet_{epoch+1}_model.pth")
    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")


  1%|          | 22/1875 [01:38<2:18:13,  4.48s/it]


KeyboardInterrupt: 