<a href="https://colab.research.google.com/github/Parv-Agarwal/Internship-project/blob/main/SFDA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [21]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import os
import numpy as np
from PIL import Image


In [22]:
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import os

class MNISTDataset(Dataset):
    """MNIST wrapper that yields ``(image, one_hot_label)`` pairs.

    The split is selected from ``file_name``: a name containing the substring
    ``'train'`` picks the training split, anything else the test split.  The
    images come from ``torchvision.datasets.MNIST`` (root ``./data``,
    ``download=True``); ``file_name`` itself is never opened.

    Args:
        file_name: Name inspected only for the substring ``'train'``.
        max_load: Optional cap on the number of examples kept (leading
            slice).  Ignored when ``None``, non-positive, or not smaller
            than the split size.
        transform: Optional callable applied to the PIL image.  When falsy,
            ``transforms.ToTensor()`` is applied instead.
    """

    def __init__(self, file_name, max_load=None, transform=None):
        self.transform = transform

        # A single constructor call; the split flag is derived from the name.
        use_train_split = 'train' in file_name
        mnist = datasets.MNIST(root='./data', train=use_train_split, download=True)
        self.data = mnist.data
        self.labels = mnist.targets

        # Keep only the first ``max_load`` examples when a usable cap is given.
        if max_load is not None and 0 < max_load < len(self.data):
            self.data = self.data[:max_load]
            self.labels = self.labels[:max_load]
            print(f'<mnist> loading only {max_load} examples')

        print('<mnist> done')

    def __len__(self):
        """Return the number of examples kept."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, one_hot)`` for the example at ``index``.

        ``one_hot`` is a length-10 float tensor with a single 1.0 at the
        label position.
        """
        raw = self.data[index]
        digit = self.labels[index]

        # The transform pipeline operates on PIL images.
        picture = transforms.ToPILImage()(raw)
        if self.transform:
            picture = self.transform(picture)
        else:
            picture = transforms.ToTensor()(picture)

        target = torch.zeros(10)
        target[digit] = 1.0
        return picture, target


In [23]:
# dataset_mnistM.py

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import os
from torchvision import transforms

class MNISTMDataset(Dataset):
    """MNIST-M dataset loaded from a file written with ``torch.save``.

    The file may hold either a sequence (tuple or list) whose first two items
    are the images and the labels, or a mapping with ``'data'`` and
    ``'labels'`` entries.

    Args:
        file_name: Path of the serialized dataset.
        max_load: Optional cap on the number of examples kept (leading
            slice).  Ignored when ``None``, non-positive, or not smaller
            than the number of stored examples.
        transform: Optional callable applied to the PIL image.  When falsy,
            ``transforms.ToTensor()`` is applied instead.

    Raises:
        FileNotFoundError: If ``file_name`` is not an existing file.
    """

    def __init__(self, file_name, max_load=None, transform=None):
        self.transform = transform

        if not os.path.isfile(file_name):
            raise FileNotFoundError(f"File {file_name} not found.")

        # NOTE(review): torch.load unpickles the file; only load dataset
        # files from a trusted source.
        data_dict = torch.load(file_name)

        # Sequence payloads carry (data, labels) in the first two positions;
        # lists are accepted as well as tuples.
        if isinstance(data_dict, (tuple, list)):
            self.data = data_dict[0]
            self.labels = data_dict[1]
        else:
            self.data = data_dict['data']
            self.labels = data_dict['labels']

        # len() works for tensors, ndarrays and plain sequences alike, which
        # matches the container types __getitem__ is prepared to handle.
        n_example = len(self.data)
        print(f'nExample {n_example}')

        # Limit the number of examples if max_load is specified
        if max_load is not None and 0 < max_load < n_example:
            n_example = max_load
            print(f'<mnistM> loading only {n_example} examples')
            self.data = self.data[:n_example]
            self.labels = self.labels[:n_example]

        print('<mnistM> done')

    def __len__(self):
        """Return the number of examples kept."""
        return len(self.labels)

    def __getitem__(self, index):
        """Return ``(image, one_hot)`` for the example at ``index``.

        ``one_hot`` is a length-10 float tensor with a single 1.0 at the
        label position.
        """
        img = self.data[index]
        label = self.labels[index]

        if isinstance(img, np.ndarray):
            img = Image.fromarray(img.astype(np.uint8))
        elif torch.is_tensor(img):
            if img.dim() == 3:
                # ToPILImage expects (C, H, W) with at most 4 channels.
                # Permute (H, W, C) images as before, but leave alone a
                # tensor that is unmistakably channel-first already (small
                # leading dim, large trailing dim); permuting that one would
                # only produce an invalid channel count.
                already_channel_first = img.size(0) <= 4 and img.size(-1) > 4
                if not already_channel_first:
                    img = img.permute(2, 0, 1)
            elif img.dim() == 2:
                # If the image has shape (H, W), add a channel dimension
                img = img.unsqueeze(0)
            img = transforms.ToPILImage()(img)

        if self.transform:
            img = self.transform(img)
        else:
            img = transforms.ToTensor()(img)

        # Convert label to one-hot encoding
        label_one_hot = torch.zeros(10)
        label_one_hot[label] = 1.0

        return img, label_one_hot


In [24]:
class LogSumExp(nn.Module):
    """Module computing ``log(sum(exp(input)))`` over dimension 1.

    The reduced dimension is kept, so an input of shape ``(N, C, ...)``
    yields ``(N, 1, ...)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return the log-sum-exp of ``input`` along dim 1 (keepdim=True)."""
        # torch.logsumexp is the numerically stabilised built-in; unlike the
        # manual "subtract the max" form it does not turn a row whose maximum
        # is infinite into NaN via ``inf - inf``.
        return torch.logsumexp(input, dim=1, keepdim=True)


In [25]:
# Run configuration shared by the notebook cells.
opt = dict(
    dataset='mnist',
    batchSize=64,
    loadSize=33,
    fineSize=32,
    nz=100,                  # dimensionality of the noise vector Z
    ngf=64,                  # generator filters in the first conv layer
    ndf=64,                  # discriminator filters in the first conv layer
    nThreads=4,              # data loading threads
    niter=10000,             # iterations at the starting learning rate
    lr=0.0002,               # initial learning rate for Adam
    beta1=0.5,               # momentum term of Adam
    ntrain=float('inf'),     # examples per epoch
    display=0,               # display samples while training
    display_id=0,            # display window id
    gpu=1,                   # 0 is CPU mode; X > 0 is GPU mode on GPU X
    name='Logfiles',
    noise='normal',          # 'uniform' or 'normal'
    epoch_save_modulo=1,
    manual_seed=4,           # random seed
    nc=3,                    # channels in the input images
    save='logs/',            # directory to save logs
    data_root='./data',      # root directory for datasets
    lamda=1,                 # lambda value for the gradient reversal layer
    baseLearningRate=0.0002,
    max_epoch=10000,
    gamma=0.001,
    power=0.75,
    max_epoch_grl=10000,
    alpha=10,
)

# Epoch threshold compared against the epoch counter by the driver loop.
train_gen_epoch = 25

In [26]:
# Seed Python's and PyTorch's RNGs from the configured seed and keep
# PyTorch on a single intra-op thread.
import random
random.seed(opt['manual_seed'])
torch.manual_seed(opt['manual_seed'])
torch.set_num_threads(1)

# Default to the CPU; switch to the configured GPU (1-based in ``opt``)
# only when CUDA is usable and GPU mode was requested.
device = torch.device('cpu')
if torch.cuda.is_available() and opt['gpu'] > 0:
    torch.cuda.manual_seed_all(opt['manual_seed'])
    device = torch.device(f'cuda:{opt["gpu"] - 1}')

print(f"Random Seed: {opt['manual_seed']}")
print(f"Device: {device}")

# Preprocessing for MNIST: resize, replicate the grey channel to three
# channels, convert to a tensor, then normalise each channel with
# mean 0.5 / std 0.5.
transform_mnist = transforms.Compose(
    [
        transforms.Resize(opt['fineSize']),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ]
)

# Preprocessing for MNIST-M: same as above without the grey-to-3-channel step.
transform_mnistM = transforms.Compose(
    [
        transforms.Resize(opt['fineSize']),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ]
)

Random Seed: 4
Device: cpu


In [27]:
# ---- MNIST ----
# MNISTDataset only checks these names for the substring 'train'.
max_train_load = None  # None or an integer cap
max_test_load = None   # None or an integer cap
mnist_train_path = 'mnist_train.pt'  # Adjust the path as needed
mnist_test_path = 'mnist_test.pt'    # Adjust the path as needed

mnist_train_dataset = MNISTDataset(
    mnist_train_path,
    max_load=max_train_load,
    transform=transform_mnist,
)
mnist_test_dataset = MNISTDataset(
    mnist_test_path,
    max_load=max_test_load,
    transform=transform_mnist,
)

mnist_train_loader = DataLoader(
    mnist_train_dataset,
    batch_size=opt['batchSize'],
    shuffle=True,
    num_workers=opt['nThreads'],
)
mnist_test_loader = DataLoader(
    mnist_test_dataset,
    batch_size=opt['batchSize'],
    shuffle=False,
    num_workers=opt['nThreads'],
)

print(f"MNIST Dataset: Size: {len(mnist_train_dataset)}")

# ---- MNIST-M ----
Num_Train_Target = 59001
Num_Test_Target = 10001
mnistm_train_path = 'mnist_m_train.pt'  # Adjust the path as needed
mnistm_test_path = 'mnist_m_test.pt'    # Adjust the path as needed

mnistm_train_dataset = MNISTMDataset(
    mnistm_train_path,
    max_load=Num_Train_Target,
    transform=transform_mnistM,
)
mnistm_test_dataset = MNISTMDataset(
    mnistm_test_path,
    max_load=Num_Test_Target,
    transform=transform_mnistM,
)

mnistm_train_loader = DataLoader(
    mnistm_train_dataset,
    batch_size=opt['batchSize'],
    shuffle=True,
    num_workers=opt['nThreads'],
)
mnistm_test_loader = DataLoader(
    mnistm_test_dataset,
    batch_size=opt['batchSize'],
    shuffle=False,
    num_workers=opt['nThreads'],
)

print(f"MNIST-M Dataset: Size: {len(mnistm_train_dataset)}")

<mnist> done
<mnist> done
MNIST Dataset: Size: 60000


  data_dict = torch.load(file_name)


nExample 60000
<mnistM> loading only 59001 examples
<mnistM> done
nExample 10000
<mnistM> done
MNIST-M Dataset: Size: 59001


In [28]:
# Initialize weights
def weights_init(m):
    """Initialise one module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights
    drawn from N(0, 0.02) and a zero bias.  Modules whose class name contains
    ``BatchNorm`` get weights drawn from N(1, 0.02) and a zero bias.  All
    other modules are left untouched.

    A parameter that is absent or ``None`` (e.g. ``bias=False`` or
    ``BatchNorm*(affine=False)``) is skipped instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname or 'Linear' in classname:
        weight_mean = 0.0
    elif 'BatchNorm' in classname:
        weight_mean = 1.0
    else:
        return

    # nn.init functions run without autograd tracking, so the parameters
    # can be passed directly.
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)
    if bias is not None:
        nn.init.constant_(bias, 0)

In [29]:
class Generator(nn.Module):
    """Transposed-convolution generator conditioned on a one-hot class code.

    The input has ``nz + 10`` channels (noise plus the 10-way class vector)
    on a 1x1 grid; the output has ``nc`` channels on a 32x32 grid and is
    squashed by ``tanh``.

    Args:
        nz: Number of noise channels.
        ngf: Base number of generator feature maps.
        nc: Number of output image channels.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        # (output channels, stride, padding) for each upsampling stage:
        # 1x1 -> 4x4 -> 8x8 -> 16x16.
        stages = (
            (ngf * 8, 1, 0),
            (ngf * 4, 2, 1),
            (ngf * 2, 2, 1),
        )

        layers = []
        in_channels = nz + 10
        for out_channels, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))
            in_channels = out_channels

        # Final stage: 16x16 -> 32x32 with nc channels, values in [-1, 1].
        layers.append(nn.ConvTranspose2d(in_channels, nc, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Map the conditioned noise tensor to an image batch."""
        generated = self.main(input)
        return generated

In [30]:
class Discriminator(nn.Module):
    """Convolutional real/fake discriminator for 32x32 images.

    Args:
        nc: Number of input image channels.
        ndf: Base number of discriminator feature maps.
    """

    def __init__(self, nc, ndf):
        super().__init__()

        # 32x32 -> 16x16, no batch norm on the first stage.
        stem = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 16x16 -> 8x8
        mid = [
            nn.Conv2d(ndf, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 8x8 -> 4x4
        deep = [
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # 4x4 -> 1x1 single score squashed into (0, 1).
        head = [
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        ]

        self.main = nn.Sequential(*stem, *mid, *deep, *head)

    def forward(self, input):
        """Return one probability per input image, shaped ``(N, 1)``."""
        scores = self.main(input)
        return scores.view(-1, 1)

In [31]:
from torch.autograd import Function
class GradientReversalFunction(Function):
    """Autograd function: identity forward, negated and scaled backward."""

    @staticmethod
    def forward(ctx, x, lambda_):
        """Return ``x`` unchanged and remember ``lambda_`` for backward."""
        ctx.lambda_ = lambda_
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-lambda_ * grad_output`` for ``x``; ``lambda_`` gets no gradient."""
        flipped = grad_output.neg()
        scaled = flipped * ctx.lambda_
        return scaled, None

def grad_reverse(x, lambda_=1.0):
    """Apply ``GradientReversalFunction`` to ``x`` with scale ``lambda_``."""
    reverse = GradientReversalFunction.apply
    return reverse(x, lambda_)

In [32]:
class FeatureExtractor(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a flatten.

    ``fc_features`` is fixed at ``48 * 5 * 5``, the flattened size produced
    for 3-channel 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Each stage: 5x5 convolution, in-place ReLU, 2x2 max pooling.
        first_stage = [nn.Conv2d(3, 32, 5), nn.ReLU(True), nn.MaxPool2d(2, 2)]
        second_stage = [nn.Conv2d(32, 48, 5), nn.ReLU(True), nn.MaxPool2d(2, 2)]
        self.conv = nn.Sequential(*first_stage, *second_stage)
        self.fc_features = 48 * 5 * 5  # flattened feature length

    def forward(self, x):
        """Return the conv features reshaped to ``(-1, fc_features)``."""
        feature_maps = self.conv(x)
        return feature_maps.view(-1, self.fc_features)


In [33]:
class ClassClassifier(nn.Module):
    """Three-layer MLP mapping flattened features to 10 log-probabilities."""

    def __init__(self):
        super().__init__()
        in_features = 48 * 5 * 5
        hidden = 100
        n_classes = 10
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, n_classes),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Return per-class log-probabilities for each row of ``x``."""
        return self.fc(x)

In [34]:
class DomainClassifier(nn.Module):
    """Two-way domain classifier placed behind a gradient reversal step.

    Args:
        lambda_: Scale passed to ``grad_reverse`` on every forward call.
    """

    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_
        in_features = 48 * 5 * 5
        hidden = 100
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, 2),
        )

    def forward(self, x):
        """Return two domain logits per row of ``x``."""
        # The reversal is applied to the incoming features before the MLP.
        reversed_features = grad_reverse(x, self.lambda_)
        return self.fc(reversed_features)

In [36]:
from torchsummary import summary

# Build the five networks and move each one to the selected device.
# NOTE(review): all networks are constructed first and re-initialised
# afterwards; changing that order presumably changes how the seeded RNG is
# consumed and therefore the starting weights -- confirm before reordering.
netG = Generator(opt['nz'], opt['ngf'], opt['nc']).to(device)
netD = Discriminator(opt['nc'], opt['ndf']).to(device)
feature_extractor = FeatureExtractor().to(device)
class_classifier = ClassClassifier().to(device)
domain_classifier = DomainClassifier(lambda_=opt['lamda']).to(device)

# Overwrite the default initialisation of every sub-module via weights_init.
netG.apply(weights_init)
netD.apply(weights_init)
feature_extractor.apply(weights_init)
class_classifier.apply(weights_init)
domain_classifier.apply(weights_init)

# Print a torchsummary layer table for each network. Every input_size is
# given without the batch dimension.

print("Generator Model Summary:")
summary(netG, input_size=(opt['nz'] + 10, 1, 1))

print("\nDiscriminator Model Summary:")
summary(netD, input_size=(opt['nc'], opt['fineSize'], opt['fineSize']))

print("\nFeature Extractor Model Summary:")
summary(feature_extractor, input_size=(opt['nc'], opt['fineSize'], opt['fineSize']))

print("\nClass Classifier Model Summary:")
summary(class_classifier, input_size=(48 * 5 * 5,))

print("\nDomain Classifier Model Summary:")
summary(domain_classifier, input_size=(48 * 5 * 5,))



# Loss functions
# NOTE(review): classification_loss and log_sum_exp are not referenced by the
# train/test functions visible in this notebook -- confirm whether they are
# still needed.
adversarial_loss = nn.BCELoss().to(device)
classification_loss = nn.NLLLoss().to(device)
cross_entropy_loss = nn.CrossEntropyLoss().to(device)
log_sum_exp = LogSumExp().to(device)

# Optimizers: Adam (lr = opt['lr'], beta1 = opt['beta1']) for the generator
# and discriminator; SGD with momentum 0.9 (lr = opt['baseLearningRate'])
# for the feature extractor and the two classifiers.
optimizer_G = optim.Adam(netG.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
optimizer_D = optim.Adam(netD.parameters(), lr=opt['lr'], betas=(opt['beta1'], 0.999))
optimizer_feature = optim.SGD(feature_extractor.parameters(), lr=opt['baseLearningRate'], momentum=0.9)
optimizer_class = optim.SGD(class_classifier.parameters(), lr=opt['baseLearningRate'], momentum=0.9)
optimizer_domain = optim.SGD(domain_classifier.parameters(), lr=opt['baseLearningRate'], momentum=0.9)

Generator Model Summary:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 512, 4, 4]         901,120
       BatchNorm2d-2            [-1, 512, 4, 4]           1,024
              ReLU-3            [-1, 512, 4, 4]               0
   ConvTranspose2d-4            [-1, 256, 8, 8]       2,097,152
       BatchNorm2d-5            [-1, 256, 8, 8]             512
              ReLU-6            [-1, 256, 8, 8]               0
   ConvTranspose2d-7          [-1, 128, 16, 16]         524,288
       BatchNorm2d-8          [-1, 128, 16, 16]             256
              ReLU-9          [-1, 128, 16, 16]               0
  ConvTranspose2d-10            [-1, 3, 32, 32]           6,144
             Tanh-11            [-1, 3, 32, 32]               0
Total params: 3,530,496
Trainable params: 3,530,496
Non-trainable params: 0
----------------------------------------------------------------
I

In [37]:
import torch.nn.functional as F

def train(epoch):
    """Run one training epoch over paired MNIST / MNIST-M batches.

    Each step performs three updates:

    1. Discriminator: real MNIST images vs. detached generated images.
    2. Generator: adversarial loss plus the classification loss of the
       generated images against the class labels they were conditioned on.
    3. Feature extractor / class classifier / domain classifier: the
       classification gradients from step 2 plus the domain loss that
       separates generated images (label 0) from MNIST-M images (label 1).

    Pairs in which either batch is smaller than ``opt['batchSize']`` are
    skipped.

    Args:
        epoch: Epoch number, used only in the progress message.

    Returns:
        Tuple ``(avg_acc, avg_loss)``: the mean classifier accuracy on the
        generated images and the mean generator adversarial loss over the
        processed batches.  ``(0.0, 0.0)`` when no batch was processed.
    """
    netG.train()
    netD.train()
    feature_extractor.train()
    class_classifier.train()
    domain_classifier.train()

    avg_loss = 0.0
    avg_acc = 0.0
    count = 0

    len_dataloader = min(len(mnist_train_loader), len(mnistm_train_loader))

    # Iterate both loaders in lockstep with one iterator each.  zip() stops
    # at the shorter loader.  Calling next(iter(loader)) per step would
    # rebuild the loader iterator (and its workers) for every batch and
    # would only ever yield the first batch of a fresh pass.
    paired_batches = zip(mnist_train_loader, mnistm_train_loader)
    for batch_idx, ((source_data, _), (target_data, _)) in enumerate(paired_batches):
        if source_data.size(0) != opt['batchSize'] or target_data.size(0) != opt['batchSize']:
            continue  # Skip incomplete batch

        # Move data to device
        source_data = source_data.to(device)
        target_data = target_data.to(device)

        batch_size = source_data.size(0)
        label_real = torch.full((batch_size, 1), 1.0, device=device)
        label_fake = torch.full((batch_size, 1), 0.0, device=device)

        # Sample the classes to generate and encode them as 1x1 "images".
        class_labels = torch.randint(0, 10, (batch_size,), device=device)
        one_hot_labels = F.one_hot(class_labels, num_classes=10).float()
        one_hot_labels = one_hot_labels.view(batch_size, 10, 1, 1)

        # Concatenate noise and one-hot labels, then generate fake images
        noise = torch.randn(batch_size, opt['nz'], 1, 1, device=device)
        noise_with_labels = torch.cat((noise, one_hot_labels), 1)
        fake_images = netG(noise_with_labels)

        # ---- Train Discriminator ----
        netD.zero_grad()
        output_real = netD(source_data)
        errD_real = adversarial_loss(output_real, label_real)
        # detach(): the discriminator update must not reach the generator.
        output_fake = netD(fake_images.detach())
        errD_fake = adversarial_loss(output_fake, label_fake)
        errD = errD_real + errD_fake
        errD.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        # Clear the classifier-side gradients *before* the classification
        # loss is back-propagated.  Clearing them afterwards would discard
        # these gradients, leaving optimizer_class.step() with nothing to
        # apply.
        netG.zero_grad()
        feature_extractor.zero_grad()
        class_classifier.zero_grad()
        domain_classifier.zero_grad()

        output_fake = netD(fake_images)
        errG = adversarial_loss(output_fake, label_real)

        features = feature_extractor(fake_images)
        class_outputs = class_classifier(features)
        class_loss = cross_entropy_loss(class_outputs, class_labels)

        # A single backward over the summed loss gives the same gradients as
        # two passes with retain_graph=True.
        (errG + class_loss).backward()
        optimizer_G.step()

        # ---- Update feature extractor and classifiers ----
        source_domain_labels = torch.zeros(batch_size, dtype=torch.long, device=device)
        target_domain_labels = torch.ones(batch_size, dtype=torch.long, device=device)

        # Generated images act as the source domain; detach() keeps the
        # domain loss away from the generator.
        features_source = feature_extractor(fake_images.detach())
        features_target = feature_extractor(target_data)
        domain_output_source = domain_classifier(features_source)
        domain_output_target = domain_classifier(features_target)
        domain_output = torch.cat((domain_output_source, domain_output_target), 0)
        domain_labels = torch.cat((source_domain_labels, target_domain_labels), 0)

        domain_loss = cross_entropy_loss(domain_output, domain_labels)
        # Accumulates on top of the classification gradients from above.
        domain_loss.backward()
        optimizer_feature.step()
        optimizer_class.step()
        optimizer_domain.step()

        # Update average loss and accuracy
        avg_loss += errG.item()
        _, predicted = torch.max(class_outputs.detach(), 1)
        correct = (predicted == class_labels).sum().item()
        avg_acc += correct / batch_size
        count += 1

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{opt["niter"]}] Batch [{batch_idx}/{len_dataloader}] '
                  f'Loss D: {errD.item():.4f}, Loss G: {errG.item():.4f}, '
                  f'Class Loss: {class_loss.item():.4f}, Domain Loss: {domain_loss.item():.4f}')

    # Every pair may have been skipped (e.g. datasets smaller than a batch).
    if count == 0:
        return 0.0, 0.0

    avg_loss /= count
    avg_acc /= count
    return avg_acc, avg_loss

In [38]:
def test(epoch):
    """Evaluate classification accuracy on the MNIST-M test loader.

    Every batch is evaluated, including a final batch smaller than
    ``opt['batchSize']``, so the accuracy covers the whole test set.

    Args:
        epoch: Unused; kept so existing callers keep working.

    Returns:
        Fraction of correctly classified test images in ``[0, 1]``, or
        ``0.0`` when the loader yields no samples.
    """
    netG.eval()
    feature_extractor.eval()
    class_classifier.eval()
    total = 0
    correct = 0

    with torch.no_grad():
        for data, labels_one_hot in mnistm_test_loader:
            data = data.to(device)
            # The datasets yield one-hot labels; recover the class indices.
            labels = torch.argmax(labels_one_hot, dim=1).to(device)
            features = feature_extractor(data)
            outputs = class_classifier(features)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Avoid dividing by zero when nothing was evaluated.
    if total == 0:
        print('Test Accuracy: n/a (no test samples)')
        return 0.0

    accuracy = correct / total
    print(f'Test Accuracy: {accuracy * 100:.2f}%')
    return accuracy

In [None]:
# Train every epoch; evaluate only once the epoch counter has passed
# train_gen_epoch.
for epoch in range(1, opt['niter'] + 1):
    train_acc, train_loss = train(epoch)
    if epoch <= train_gen_epoch:
        continue
    test_acc = test(epoch)
    # Save model checkpoints if needed


