In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm
import collections
import shutil
import torch.nn.functional as F
from torchvision import models

In [2]:
# Define device
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


In [3]:
# Define constants
IMG_HEIGHT, IMG_WIDTH = 224, 224
BATCH_SIZE = 8
NUM_EPOCHS = 500
data_dir = 'C:\\Users\\ASUS\\Desktop\\DIFL\\22march DIFL'
shenzen_dir = os.path.join(data_dir, '0')
usa_dir = os.path.join(data_dir, '1')
MODEL_PATH_GC = os.path.join(data_dir, 'Models', 'modelGC.pth')
MODEL_PATH_G = os.path.join(data_dir, 'Models', 'modelG.pth')
MODEL_PATH_D = os.path.join(data_dir, 'Models', 'modelC.pth')

In [4]:
# Define transformations
transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.Grayscale(),  # Convert to grayscale
    transforms.Lambda(lambda x: x.convert('RGB'))  # Convert grayscale to RGB
])

In [42]:
# Define dataset class
class TBDataset(Dataset):
    def __init__(self, root_dir, label_type, distribution_type, transform=None):
        self.transform = transform
        self.label_type = label_type  # 'classification_label' or 'domain_label'
        self.distribution_type = distribution_type  # 'source' or 'target'
        
        self.folders_list = os.listdir(root_dir)
        self.Img_list = []
        for folder in self.folders_list:
            self.Img_list.extend(
                [os.path.join(root_dir, folder, path) for path in os.listdir(os.path.join(root_dir, folder))]
            )

    def __len__(self):
        return len(self.Img_list)

    def __getitem__(self, index):
        image = Image.open(self.Img_list[index])
        image = image.convert("L")
        label = None
        if self.label_type == 'classification_label':
            label = int(self.Img_list[index].split(os.path.sep)[-2])
        elif self.label_type == 'domain_label':
            if self.distribution_type == 'source':
                label = 0
            elif self.distribution_type == 'target':
                label = 1
        if self.transform is not None:
            image = self.transform(image)
        x = transforms.ToTensor()(image)
        y = np.array(label)
        y = torch.from_numpy(y)
        y = y.type(torch.LongTensor)
        return x, y

In [43]:
# Define function to create a balanced subset
def create_subset(dataset, data_limit, num_classes=2):
    data_indices = []
    per_class_data_limit = data_limit // num_classes
    target_counter = collections.Counter()
    with tqdm(total=len(dataset)) as bar:
        for idx, data in enumerate(dataset):
            if idx == len(dataset) - 1:
                for i in range(data_limit - per_class_data_limit * num_classes):
                    data_indices.append(idx)
            target = data[1].item()
            target_counter[target] += 1
            if target_counter[target] <= per_class_data_limit:
                data_indices.append(idx)
            bar.update(1)

    sub_dataset = torch.utils.data.Subset(dataset, data_indices)
    return sub_dataset

In [44]:
# Define function to split dataset in a stratified manner
def stratify_split(dataset, split):
    train_indices = []
    test_indices = []
    target_stat = collections.Counter()
    with tqdm(total=len(dataset)) as bar:
        for idx, data in enumerate(dataset):
            target = data[1].item()
            target_stat[target] += 1
        bar.update(1)

    for k in target_stat.keys():
        target_stat[k] = int(target_stat[k] * split / 100.0)
    target_counter = collections.Counter()
    with tqdm(total=len(dataset)) as bar:
        for idx, data in enumerate(dataset):
            target = data[1].item()
            target_counter[target] += 1
            if target_counter[target] <= target_stat[target]:
                train_indices.append(idx)
            else:
                test_indices.append(idx)
            bar.update(1)

    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    test_dataset = torch.utils.data.Subset(dataset, test_indices)
    return train_dataset, test_dataset

In [45]:
# Define function to count number of parameters in a model
def count_num_param(model):
    return sum(p.numel() for p in model.parameters())


In [46]:
# Define function to print number of parameters with formatting
def print_param(num_param):
    num_param = str(num_param)
    ret = []
    length = len(num_param)
    for i in range(length):
        idx = i + 1
        if idx > 1 and (idx - 1) % 3 == 0:
            ret.append(',')
        ret.append(num_param[length - idx])
    ret.reverse()
    ret = ''.join(ret)
    temp_a = '## Number of parameters: ' + ret + ' ##'
    temp_b = '#' * len(temp_a)
    columns = shutil.get_terminal_size().columns
    print(f'{temp_b}'.center(columns))
    print(f'{temp_a}'.center(columns))
    print(f'{temp_b}'.center(columns))

In [47]:
class Generator(torch.nn.Module):
    def __init__(self, num_of_filters):
        super(Generator, self).__init__()
        self.num_of_filters = num_of_filters
        self.Resnet50_model = models.resnet50(pretrained=False)
        # Modify the first convolutional layer to accept 1 input channel instead of 3
        self.Resnet50_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Remove the max pooling layer after the first convolutional block
        self.Resnet50_model.maxpool = nn.Identity()
        # Remove the final average pooling and fully connected layers
        self.Resnet50_model.fc = nn.Identity()

        # Initialize the remaining layers of the Generator
        self.A_UpConv_1 = nn.ConvTranspose2d(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.A_activation_1 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.A_UpConv_2 = nn.ConvTranspose2d(1024, self.num_of_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.A_activation_2 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.A_conv = nn.Conv2d(self.num_of_filters, self.num_of_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.A_activation_3 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)

        self.B_UpConv_1 = nn.ConvTranspose2d(self.num_of_filters, self.num_of_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.B_activation_1 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.B_conv = nn.Conv2d(self.num_of_filters, self.num_of_filters, kernel_size=3, stride=1, padding=1, bias=False)
        self.B_activation_3 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)

        self.C_UpConv_1 = nn.ConvTranspose2d(self.num_of_filters, self.num_of_filters, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.C_activation_1 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)

    def forward(self, x):
        # Ensure that the input images are correctly resized and processed as grayscale
        x = F.interpolate(x.unsqueeze(1), size=(224, 224))  # Add channel dimension and interpolate
        # Forward pass through the modified ResNet50 model
        x = self.Resnet50_model(x)
        # Reshape the output to include spatial dimensions
        x = x.view(x.size(0), -1, 1, 1)  # Assume the spatial dimensions are 1x1
        # Check the shape of the output after reshaping
        print("Shape after ResNet50 and reshaping:", x.shape)
        # Forward pass through the remaining layers of the Generator
        x = self.A_UpConv_1(x)
        x = self.A_activation_1(x)
        x = self.A_UpConv_2(x)
        x = self.A_activation_2(x)
        x = self.A_conv(x)
        x = self.A_activation_3(x)
        x = self.B_UpConv_1(x)
        x = self.B_activation_1(x)
        x = self.B_conv(x)
        x = self.B_activation_3(x)
        x = self.C_UpConv_1(x)
        x = self.C_activation_1(x)
        return x



In [48]:
# Define the Discriminator class
class Discriminator(torch.nn.Module):
    def __init__(self, num_of_filters):
        super(Discriminator, self).__init__()
        self.num_of_filters = num_of_filters
        self.vgg19_model = models.vgg19(pretrained=False)
        self.vgg19_model.features[0] = nn.Conv2d(512, 64, 3, 1, 1)
        self.vgg19_model = nn.Sequential(*list(self.vgg19_model.children())[:-2])
        self.Linear_1 = nn.Linear(self.num_of_filters, 256)
        self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_2 = nn.Linear(256, 128)
        self.Activation_2 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_3 = nn.Linear(128, 64)
        self.Activation_3 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_final = nn.Linear(64, 2)
        self.Activation_final = nn.Sigmoid()

    def forward(self, x):
        x = self.vgg19_model(x)
        x = nn.Flatten()(x)
        x = self.Linear_1(x)
        x = self.Activation_1(x)
        x = self.Linear_2(x)
        x = self.Activation_2(x)
        x = self.Linear_3(x)
        x = self.Activation_3(x)
        x = self.Linear_final(x)
        x = self.Activation_final(x)
        return x


# Classifier

In [49]:
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class Classifier(torch.nn.Module):
    def __init__(self, num_of_filters):
        super(Classifier, self).__init__()
        self.num_of_filters = num_of_filters
        self.vgg19_model = models.vgg19(pretrained=False)
        self.vgg19_model.features[0] = nn.Conv2d(512, 64, 3, 1, 1)
        self.vgg19_model = nn.Sequential(*list(self.vgg19_model.children())[:-1])  # Remove last pooling layer
        self.adaptive_pooling = nn.AdaptiveAvgPool2d((1, 1))  # Add adaptive pooling
        self.Linear_1 = nn.Linear(self.num_of_filters, 256)
        self.Activation_1 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_2 = nn.Linear(256, 128)
        self.Activation_2 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_3 = nn.Linear(128, 64)
        self.Activation_3 = torch.nn.LeakyReLU(negative_slope=0.2, inplace=False)
        self.Linear_final = nn.Linear(64, 2)
        self.Activation_final = nn.Sigmoid()

    def forward(self, x):
        x = self.vgg19_model(x)
        print("Shape after VGG19 model:", x.shape)
        # Apply adaptive pooling
        x = self.adaptive_pooling(x)
        print("Shape after adaptive pooling:", x.shape)
        x = nn.Flatten()(x)
        print("Shape after flattening:", x.shape)
        x = self.Linear_1(x)
        print("Shape after Linear_1:", x.shape)
        x = self.Activation_1(x)
        print("Shape after Activation_1:", x.shape)
        x = self.Linear_2(x)
        print("Shape after Linear_2:", x.shape)
        x = self.Activation_2(x)
        print("Shape after Activation_2:", x.shape)
        x = self.Linear_3(x)
        print("Shape after Linear_3:", x.shape)
        x = self.Activation_3(x)
        print("Shape after Activation_3:", x.shape)
        x = self.Linear_final(x)
        print("Shape after Linear_final:", x.shape)
        x = self.Activation_final(x)
        print("Shape after Activation_final:", x.shape)
        return x


In [50]:
# Initialize models
G = Generator(512).to(device)
D = Discriminator(512).to(device)
C = Classifier(512).to(device)

In [51]:
# Print number of parameters for each model
print_param(count_num_param(G))
print_param(count_num_param(D))
print_param(count_num_param(C))


                     ######################################                     
                     ## Number of parameters: 56,534,464 ##                     
                     ######################################                     
                     ######################################                     
                     ## Number of parameters: 20,490,178 ##                     
                     ######################################                     
                     ######################################                     
                     ## Number of parameters: 20,490,178 ##                     
                     ######################################                     


In [52]:
# Define BCE loss
criterion = nn.BCELoss()

In [53]:
# Define optimizers
Params_Classification = list(G.parameters()) + list(C.parameters())
Optimizer_Classification = optim.Adam(Params_Classification, lr=0.0002, betas=(0.5, 0.999))
Optimizer_DomainInvariance_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
Optimizer_DomainInvariance_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [54]:

# Define training function
def training(diseases_classification_trainset, domain_classification_trainset,
             MODEL_PATH_GC, MODEL_PATH_G, MODEL_PATH_D, device, num_epochs):
    G.to(device)
    C.to(device)
    D.to(device)
    G.eval()
    C.eval()
    D.eval()
    for epoch in range(num_epochs):
        for batch_number, example in enumerate(diseases_classification_trainset):
            X_s, y_s = example
            X_s, y_s = X_s.to(device), y_s.to(device)
            G_X_s = G(X_s)
            y_hat_s = C(G_X_s)
            l_c = criterion(y_hat_s, y_s)
            Optimizer_Classification.zero_grad()
            l_c.backward()
            Optimizer_Classification.step()
        torch.save({
                'model_state_dict': C.state_dict(),
                'optim_state_dict': Optimizer_Classification.state_dict(),
                'epoch': epoch
        }, MODEL_PATH_GC)
        for batch_number, example in enumerate(domain_classification_trainset):
            X, d = example
            X, d = X.to(device), d.to(device)
            G_X = G(X)
            d_hat = D(G_X)
            l_d = criterion(d_hat, d)
            d_gen = torch.full(d.shape, 0.5).to(device)
            l_g = criterion(d_hat, d_gen)

            Optimizer_DomainInvariance_G.zero_grad()
            Optimizer_DomainInvariance_D.zero_grad()
            l_d.backward()
            l_g.backward()
            Optimizer_DomainInvariance_G.step()
            Optimizer_DomainInvariance_D.step()
        torch.save({
                'model_state_dict': G.state_dict(),
                'optim_state_dict': Optimizer_DomainInvariance_G.state_dict(),
                'epoch': epoch
        }, MODEL_PATH_G)
        torch.save({
                'model_state_dict': D.state_dict(),
                'optim_state_dict': Optimizer_DomainInvariance_D.state_dict(),
                'epoch': epoch
        }, MODEL_PATH_D)


In [55]:
# Define evaluation function
def evaluate(dataset_loader, model, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataset_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [56]:
# Define datasets
shenzen_dataset = TBDataset(shenzen_dir, label_type='classification_label', distribution_type='source', transform=transform)
usa_dataset = TBDataset(usa_dir, label_type='classification_label', distribution_type='target', transform=transform)


In [57]:
# Create balanced subsets
shenzen_balanced = create_subset(shenzen_dataset, 2000)
usa_balanced = create_subset(usa_dataset, 2000)


100%|██████████| 662/662 [01:05<00:00, 10.03it/s]
100%|██████████| 138/138 [00:35<00:00,  3.93it/s]


In [58]:
# Split datasets
shenzen_trainset, shenzen_testset = stratify_split(shenzen_balanced, split=80)
usa_trainset, usa_testset = stratify_split(usa_balanced, split=80)

  0%|          | 1/662 [01:03<11:38:37, 63.42s/it]
100%|██████████| 662/662 [01:05<00:00, 10.18it/s]
  1%|          | 1/138 [00:32<1:13:58, 32.40s/it]
100%|██████████| 138/138 [00:33<00:00,  4.17it/s]


In [None]:
# Instantiate your classifier
classifier = Classifier(num_of_filters=512)

# Check the parameters of the max pooling layers
for module in classifier.vgg19_model:
    if isinstance(module, nn.MaxPool2d):
        print("Kernel size:", module.kernel_size)
        print("Stride:", module.stride)

# Continue with the rest of your code, such as training and evaluation
training(shenzen_trainset, usa_trainset, MODEL_PATH_GC, MODEL_PATH_G, MODEL_PATH_D, device, NUM_EPOCHS)

In [None]:
# Train models
training(shenzen_trainset, usa_trainset, MODEL_PATH_GC, MODEL_PATH_G, MODEL_PATH_D, device, NUM_EPOCHS)


In [None]:
# Load the trained models
C.load_state_dict(torch.load(MODEL_PATH_GC)['model_state_dict'])
G.load_state_dict(torch.load(MODEL_PATH_G)['model_state_dict'])
D.load_state_dict(torch.load(MODEL_PATH_D)['model_state_dict'])

In [62]:
# Evaluate models
shenzen_testloader = torch.utils.data.DataLoader(shenzen_testset, batch_size=BATCH_SIZE, shuffle=True)
usa_testloader = torch.utils.data.DataLoader(usa_testset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
shenzen_accuracy = evaluate(shenzen_testloader, C, device)
usa_accuracy = evaluate(usa_testloader, C, device)

In [None]:
# Print results
print("Accuracy on Shenzen dataset:", shenzen_accuracy)
print("Accuracy on USA dataset:", usa_accuracy)
