In [None]:
# Mount Google Drive so that files stored there are reachable from the Colab runtime
from google.colab import drive
drive.mount('/content/drive')

# Switch the working directory to the dataset folder; the relative CSV paths
# configured further down are resolved against this location
%cd /content/drive/MyDrive/BXT/subset/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/BXT/subset


In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os

In [None]:
# Notebook-wide settings: input CSV locations and mini-batch size
source_path, target_path = "source_subset.csv", "target_subset.csv"
batch_size = 16

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class SourceTargetDataset(Dataset):
    """Index-aligned pairs of labelled source rows and unlabelled target rows.

    Both CSV files must contain the feature columns ``dim_0`` ...
    ``dim_{num_features - 1}``; the source CSV must additionally contain a
    ``label`` column. Row ``i`` of the dataset combines row ``i`` of the source
    file with row ``i`` of the target file, and the dataset length is the
    length of the shorter file.

    Args:
        source_csv_path: Path of the source CSV (features + ``label``).
        target_csv_path: Path of the target CSV (features only).
        num_features: Number of ``dim_*`` feature columns to read
            (defaults to 1024, the value that was previously hard-coded).

    Raises:
        AssertionError: If a required column is missing from either file.
    """

    def __init__(self, source_csv_path, target_csv_path, num_features=1024):
        self.source_df = pd.read_csv(source_csv_path)
        self.target_df = pd.read_csv(target_csv_path)

        # Feature column names: dim_0 ... dim_{num_features - 1}
        self.feature_cols = [f'dim_{i}' for i in range(num_features)]

        # Validate the schema before touching any values
        assert all(col in self.source_df.columns for col in self.feature_cols), "Source CSV missing some feature columns"
        assert all(col in self.target_df.columns for col in self.feature_cols), "Target CSV missing some feature columns"
        assert 'label' in self.source_df.columns, "Source CSV missing label column"

        # Cast every feature column (and the label) to float32 in a single
        # pass instead of assigning the columns back one at a time.
        feature_dtypes = {col: np.float32 for col in self.feature_cols}
        self.source_df = self.source_df.astype({**feature_dtypes, 'label': np.float32})
        self.target_df = self.target_df.astype(feature_dtypes)

        # Materialise contiguous float32 arrays once so that __getitem__ is a
        # plain row lookup rather than a per-sample DataFrame.iloc access.
        self._source_features = np.ascontiguousarray(
            self.source_df[self.feature_cols].to_numpy(dtype=np.float32)
        )
        self._source_labels = self.source_df['label'].to_numpy(dtype=np.float32)
        self._target_features = np.ascontiguousarray(
            self.target_df[self.feature_cols].to_numpy(dtype=np.float32)
        )

        # Use the shorter of the two files so every index has both halves
        self.length = min(len(self.source_df), len(self.target_df))

        print(f"Loaded {self.length} samples")

    def __len__(self):
        """Return the number of paired samples."""
        return self.length

    def __getitem__(self, idx):
        """Return the ``idx``-th source/target pair as float32 tensors.

        Returns:
            dict with keys ``source_features`` (1-D, one value per feature
            column), ``source_label`` (0-dim scalar tensor) and
            ``target_features`` (1-D, one value per feature column).
        """
        # torch.tensor copies, so callers never alias the cached arrays
        return {
            'source_features': torch.tensor(self._source_features[idx], dtype=torch.float32),
            'source_label': torch.tensor(self._source_labels[idx], dtype=torch.float32),
            'target_features': torch.tensor(self._target_features[idx], dtype=torch.float32),
        }

def get_data_loaders(source_csv_path, target_csv_path, batch_size=32, shuffle=True, num_workers=2):
    """Build a DataLoader over the paired source/target CSV dataset.

    Args:
        source_csv_path: Path of the source CSV (features + label).
        target_csv_path: Path of the target CSV (features only).
        batch_size: Number of paired samples per batch.
        shuffle: Whether the loader reshuffles the samples.
        num_workers: Number of worker processes used for loading.

    Returns:
        A ``DataLoader`` wrapping a ``SourceTargetDataset``, with pinned
        memory enabled.
    """
    loader_options = {
        'batch_size': batch_size,
        'shuffle': shuffle,
        'num_workers': num_workers,
        'pin_memory': True,
    }
    paired_dataset = SourceTargetDataset(source_csv_path, target_csv_path)
    return DataLoader(paired_dataset, **loader_options)


In [None]:
class FeatureExtractor(nn.Module):
    """Two-layer MLP that maps an input vector to a 512-dimensional embedding.

    Layout: Linear(input_size, 768) -> LeakyReLU(0.2) -> Linear(768, 512)
    -> LeakyReLU(0.2).
    """

    def __init__(self, input_size):
        super().__init__()

        # The stack is kept under `self.model` so parameter names
        # (model.0.*, model.2.*) stay the same.
        layers = [
            nn.Linear(input_size, 768),
            nn.LeakyReLU(0.2),
            nn.Linear(768, 512),
            nn.LeakyReLU(0.2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the embedding of ``x``; the last dimension becomes 512."""
        embedding = self.model(x)
        return embedding

class Classifier(nn.Module):
    """Binary classification head that outputs one sigmoid probability per sample.

    Layout: Linear(input_size, 256) -> LeakyReLU(0.2) -> Linear(256, 64)
    -> LeakyReLU(0.2) -> Linear(64, 1) -> Sigmoid.
    """

    def __init__(self, input_size):
        super().__init__()

        hidden = [
            nn.Linear(input_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 64),
            nn.LeakyReLU(0.2),
        ]
        head = [nn.Linear(64, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*hidden, *head)

    def forward(self, x):
        """Return the probabilities flattened to a 1-D tensor."""
        probabilities = self.model(x)
        return probabilities.view(-1)

class Discriminator(nn.Module):
    """Domain discriminator that outputs one sigmoid score per sample.

    Layout: Linear(input_size, 256) -> LeakyReLU(0.2) -> Linear(256, 64)
    -> LeakyReLU(0.2) -> Linear(64, 1) -> Sigmoid.
    """

    def __init__(self, input_size):
        super().__init__()

        # Hidden layers are generated from the width sequence; the layers are
        # created in the same order as a hand-written stack would be.
        widths = (input_size, 256, 64)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        stack += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Return the scores flattened to a 1-D tensor."""
        scores = self.model(x)
        return scores.view(-1)

In [None]:
class DomainAdaptation():

    def __init__(self, source_csv_path, target_csv_path, batch_size=16):
        """Create the data loader, the five networks and their optimizers.

        Args:
            source_csv_path: Path of the labelled source CSV.
            target_csv_path: Path of the unlabelled target CSV.
            batch_size: Mini-batch size used by the data loader.
        """
        # Use the paths supplied by the caller. The previous code read the
        # notebook-level globals `source_path` / `target_path` here, so the
        # constructor arguments were silently ignored.
        self.dataloader = get_data_loaders(source_csv_path, target_csv_path, batch_size=batch_size)

        self.G = FeatureExtractor(input_size=1024).to(device)               # Generator
        self.C = Classifier(input_size=512).to(device)                      # Classifier
        self.C1 = Classifier(input_size=512).to(device)                     # Classifier-1
        self.C2 = Classifier(input_size=512).to(device)                     # Classifier-2
        self.D = Discriminator(input_size=512).to(device)                   # Discriminator

        self.batch_size = batch_size
        self.lr = 1e-3
        self.epochs = 20

        self.set_optimizer()                # One Adam optimizer per network
        self.reset_grad()

        # Buffer holding the pseudo-labels of the current target-domain batch
        self.output_cr_t_C_label = np.zeros(self.batch_size)

        self.checkpoint_dir = 'sample_data'            # Directory where the weights are saved

        # Create the checkpoint directory if it doesn't exist
        os.makedirs(self.checkpoint_dir, exist_ok=True)


    def set_optimizer(self):
        self.opt_g = optim.Adam(self.G.parameters(), lr=self.lr, weight_decay=0.0005)
        self.opt_c = optim.Adam(self.C.parameters(), lr=self.lr, weight_decay=0.0005)
        self.opt_c1 = optim.Adam(self.C1.parameters(), lr=self.lr, weight_decay=0.0005)
        self.opt_c2 = optim.Adam(self.C2.parameters(), lr=self.lr, weight_decay=0.0005)
        self.opt_d = optim.Adam(self.D.parameters(), lr=self.lr, weight_decay=0.0005)


    def reset_grad(self):
        self.opt_g.zero_grad()
        self.opt_c.zero_grad()
        self.opt_c1.zero_grad()
        self.opt_c2.zero_grad()
        self.opt_d.zero_grad()


    def discrepancy(self, out1, out2):
        """Mean absolute difference between two classifiers' two-class distributions.

        Each input is flattened to one value ``p`` per sample, expanded to the
        pair ``[1 - p, p]`` and passed through a softmax over that pair; the
        result is the mean of the element-wise absolute differences.
        """
        # NOTE(review): the classifiers in this file end in a Sigmoid, so the
        # softmax is presumably applied to values that are already
        # probabilities -- confirm that this is intended.
        def two_class_softmax(scores):
            column = scores.view(-1, 1)
            return F.softmax(torch.cat([1 - column, column], dim=1), dim=1)

        gap = two_class_softmax(out1) - two_class_softmax(out2)
        return torch.mean(torch.abs(gap))


    def linear_mmd(self, f_of_X, f_of_Y):
        """Linear-kernel MMD estimate between two 2-D feature batches.

        Returns ``|| mean(f_of_X) - mean(f_of_Y) ||^2`` with both means taken
        over dim 0, as a 0-dim tensor.

        For equally sized inputs this equals the previous formulation, which
        averaged every entry of ``delta @ delta.T`` with
        ``delta = f_of_X - f_of_Y``: the mean of that Gram matrix is the
        squared norm of the mean row of ``delta``. Computing it from the means
        avoids the N x N intermediate and also accepts inputs with different
        numbers of rows.

        Args:
            f_of_X: Tensor of shape (N, d).
            f_of_Y: Tensor of shape (M, d).
        """
        mean_gap = f_of_X.mean(dim=0) - f_of_Y.mean(dim=0)
        return (mean_gap * mean_gap).sum()


    def ent(self, output):
        """Entropy-style term of the two-class softmax built from ``output``.

        ``output`` is flattened to one value ``p`` per sample, expanded to the
        pair ``[1 - p, p]`` and passed through a softmax over that pair. The
        return value is ``-mean(q * log q)`` over every entry ``q`` of the
        result, i.e. the mean runs over both the batch and the two classes.

        The previous code added ``1e-6`` to both softmax inputs; softmax is
        invariant to adding a constant to all of its inputs, so that term had
        no effect. It also evaluated the softmax twice. ``log_softmax`` yields
        the same values in one numerically stable pass.
        """
        column = output.view(-1, 1)
        two_class = torch.cat([1 - column, column], dim=1)
        log_probs = F.log_softmax(two_class, dim=1)
        return -(log_probs.exp() * log_probs).mean()


    def save_weights(self, epoch):

        epoch_dir = os.path.join(self.checkpoint_dir, f'epoch_{epoch}')
        os.makedirs(epoch_dir, exist_ok=True)

        # Save each model's state dict
        torch.save(self.G.state_dict(), os.path.join(epoch_dir, 'feature_extractor.pth'))
        torch.save(self.C.state_dict(), os.path.join(epoch_dir, 'classifier.pth'))
        torch.save(self.C1.state_dict(), os.path.join(epoch_dir, 'classifier1.pth'))
        torch.save(self.C2.state_dict(), os.path.join(epoch_dir, 'classifier2.pth'))
        torch.save(self.D.state_dict(), os.path.join(epoch_dir, 'discriminator.pth'))

        # Save optimizer states
        torch.save(self.opt_g.state_dict(), os.path.join(epoch_dir, 'opt_g.pth'))
        torch.save(self.opt_c.state_dict(), os.path.join(epoch_dir, 'opt_c.pth'))
        torch.save(self.opt_c1.state_dict(), os.path.join(epoch_dir, 'opt_c1.pth'))
        torch.save(self.opt_c2.state_dict(), os.path.join(epoch_dir, 'opt_c2.pth'))
        torch.save(self.opt_d.state_dict(), os.path.join(epoch_dir, 'opt_d.pth'))

        # Save a single checkpoint with all models (for convenience)
        checkpoint = {
            'epoch': epoch,
            'G_state_dict': self.G.state_dict(),
            'C_state_dict': self.C.state_dict(),
            'C1_state_dict': self.C1.state_dict(),
            'C2_state_dict': self.C2.state_dict(),
            'D_state_dict': self.D.state_dict(),
            'opt_g_state_dict': self.opt_g.state_dict(),
            'opt_c_state_dict': self.opt_c.state_dict(),
            'opt_c1_state_dict': self.opt_c1.state_dict(),
            'opt_c2_state_dict': self.opt_c2.state_dict(),
            'opt_d_state_dict': self.opt_d.state_dict(),
        }
        torch.save(checkpoint, os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth'))

        print(f"Model weights saved for epoch {epoch}")


    def load_checkpoint(self, epoch):

        checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
        if os.path.exists(checkpoint_path):
            checkpoint = torch.load(checkpoint_path)

            self.G.load_state_dict(checkpoint['G_state_dict'])
            self.C.load_state_dict(checkpoint['C_state_dict'])
            self.C1.load_state_dict(checkpoint['C1_state_dict'])
            self.C2.load_state_dict(checkpoint['C2_state_dict'])
            self.D.load_state_dict(checkpoint['D_state_dict'])

            self.opt_g.load_state_dict(checkpoint['opt_g_state_dict'])
            self.opt_c.load_state_dict(checkpoint['opt_c_state_dict'])
            self.opt_c1.load_state_dict(checkpoint['opt_c1_state_dict'])
            self.opt_c2.load_state_dict(checkpoint['opt_c2_state_dict'])
            self.opt_d.load_state_dict(checkpoint['opt_d_state_dict'])

            print(f"Loaded checkpoint from epoch {epoch}")
            return epoch
        else:
            print(f"No checkpoint found for epoch {epoch}")
            return 0


    def train(self):
        A_st_min = 0
        A_st_max = 1
        min_J_w = 0
        max_J_w = 1
        A_st_norm = 0.5
        J_w_norm = 0.5
        criterion = nn.BCELoss()

        for epoch in range(self.epochs):
            self.G.train()
            self.C.train()
            self.C1.train()
            self.C2.train()
            self.D.train()
            fea_for_LDA = np.empty(shape=(0, 512))
            fea_s_for_LDA = np.empty(shape=(0, 512))
            label_for_LDA = np.empty(shape=(0, 1))
            label_s_for_LDA = []

            # Training loop
            for batch_idx, batch in enumerate(tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")):
                # Access source features (shape: B x 1024)
                source_features = batch['source_features']

                # Access source labels (shape: B)
                source_labels = batch['source_label']

                # Access target features (shape: B x 1024)
                target_features = batch['target_features']

                # Ensure correct batch size (handle last batch which might be smaller)
                actual_batch_size = source_features.size(0)
                if actual_batch_size != self.batch_size:
                    # Resize output_cr_t_C_label if needed
                    if len(self.output_cr_t_C_label) != actual_batch_size:
                        self.output_cr_t_C_label = np.zeros(actual_batch_size)

                # Sending to GPU if available
                source_features = source_features.to(device)
                source_labels = source_labels.to(device)
                target_features = target_features.to(device)

                self.reset_grad()

                # Computing T
                T_complex = A_st_norm /(A_st_norm + (1.0 - J_w_norm))
                T = T_complex.real

                # Normal Supervised Learning on Source Domain Batch
                for i in range(8):
                    feat_cr_s = self.G(source_features)
                    output_cr_s_C = self.C(feat_cr_s)
                    loss_1 = criterion(output_cr_s_C, source_labels)

                    loss_1.backward()
                    self.opt_g.step()               # Update G
                    self.opt_c.step()               # Update C
                    self.reset_grad()

                # Transferability
                for i in range(1):
                    feat_cr_s = self.G(source_features)
                    feat_cr_t = self.G(target_features)

                    # Training Discriminator of GAN
                    output_cr_s_D = self.D(feat_cr_s)
                    output_cr_t_D = self.D(feat_cr_t)
                    loss_2 = criterion(output_cr_s_D, output_cr_t_D.detach())
                    loss_2 = 0.1 * loss_2
                    loss_2.backward()
                    self.opt_d.step()               # Update D
                    self.reset_grad()

                    # Training All Classifiers
                    feat_cr_s = self.G(source_features)  # Recompute features since graph was freed
                    feat_cr_t = self.G(target_features)  # Recompute features since graph was freed

                    output_cr_s_C = self.C(feat_cr_s)
                    output_cr_t_C = self.C(feat_cr_t)
                    output_cr_s_C1 = self.C1(feat_cr_s)
                    output_cr_s_C2 = self.C2(feat_cr_s)
                    output_cr_t_C1 = self.C1(feat_cr_t)
                    output_cr_t_C2 = self.C2(feat_cr_t)

                    source_labels_reshaped = source_labels.view(-1, 1)
                    loss_cr_s = criterion(output_cr_s_C1, source_labels) + criterion(output_cr_s_C2, source_labels) + criterion(output_cr_s_C, source_labels)
                    loss_dis1_t = self.discrepancy(output_cr_t_C1, output_cr_t_C2).negative()       # Negative
                    loss_3 = loss_cr_s + loss_dis1_t
                    loss_3.backward()

                    self.opt_c1.step()                # Update all 3 classifiers
                    self.opt_c2.step()
                    self.opt_c.step()
                    self.reset_grad()

                # Balance of transferability and discriminability
                for i in range(8):
                    feat_cr_s = self.G(source_features)
                    feat_cr_t = self.G(target_features)

                    output_cr_s_D = self.D(feat_cr_s)
                    output_cr_t_D = self.D(feat_cr_t)
                    loss_4 = criterion(output_cr_s_D, output_cr_t_D.detach()).negative()           # Negative
                    loss_4 = 0.2 * loss_4

                    # Recompute target features for classifier outputs
                    feat_cr_t = self.G(target_features)
                    output_cr_t_C = self.C(feat_cr_t)
                    output_cr_t_C1 = self.C1(feat_cr_t)
                    output_cr_t_C2 = self.C2(feat_cr_t)

                    # Calculating class discrimination loss (L_cd)
                    loss_51 = self.discrepancy(output_cr_t_C1, output_cr_t_C2)
                    loss_52 = self.discrepancy(output_cr_t_C, output_cr_t_C1)
                    loss_53 = self.discrepancy(output_cr_t_C, output_cr_t_C2)
                    loss_5 = loss_51 + loss_52 + loss_53

                    loss_all = T*loss_4 + (1.0-T)*loss_5
                    loss_all.backward()
                    self.opt_g.step()           # Update G
                    self.reset_grad()

                # Re-weighting based on the uncertainty of pseudo-label
                for i in range(1):
                    feat_cr_t = self.G(target_features)
                    output_cr_t_C = self.C(feat_cr_t)
                    output_cr_t_C_de = output_cr_t_C.detach()

                    # Update output_cr_t_C_label with correct batch size handling
                    for ii in range(actual_batch_size):
                        self.output_cr_t_C_label[ii] = output_cr_t_C_de[ii].item()

                    output_cr_t_C_labels = torch.from_numpy(self.output_cr_t_C_label[:actual_batch_size]).to(device).float()
                    Ly_ce_t = criterion(output_cr_t_C, output_cr_t_C_labels)

                    # Ensure ent method is correctly defined and uses dim=1 for softmax
                    H_emp = self.ent(output_cr_t_C)
                    mu = (torch.exp(-H_emp)-1.0/2)/(1-1.0/2)
                    Ly_loss = 2*(mu*Ly_ce_t+(1-mu)*H_emp)

                    Ly_loss.backward()
                    self.opt_g.step()               # Update G
                    self.opt_c.step()               # Update C
                    self.reset_grad()

                # Data for A_st(MMD) and J_w
                with torch.no_grad():  # No need to track gradients for these computations
                    # for source
                    feat_cr_s = self.G(source_features)
                    feat_cr_t = self.G(target_features)
                    label_predi = self.C(feat_cr_t)           # Pseudo-Labels for Target Domain Batch

                    # Convert to numpy for later calculations
                    feat_s_test_np = feat_cr_s.cpu().detach().numpy()
                    label_s_test_np = source_labels.cpu().detach().numpy()

                    # Append to our collected data
                    label_s_for_LDA = np.append(label_s_for_LDA, label_s_test_np.flatten())
                    fea_s_for_LDA = np.vstack((fea_s_for_LDA, feat_s_test_np)) if fea_s_for_LDA.size > 0 else feat_s_test_np

                    # for target
                    feat_test_np = feat_cr_t.cpu().detach().numpy()
                    fea_for_LDA = np.vstack((fea_for_LDA, feat_test_np)) if fea_for_LDA.size > 0 else feat_test_np

                    # Pseudo-Labels for Target Domain
                    label_t = (label_predi >= 0.5).float()  # Binary classification threshold
                    label_test_np = label_t.cpu().detach().numpy()
                    label_test_np = label_test_np.reshape(-1, 1)  # Reshape to ensure correct dimensions

                    label_for_LDA = np.vstack((label_for_LDA, label_test_np)) if label_for_LDA.size > 0 else label_test_np

            # Calculate source domain accuracy after each epoch
            source_accuracy = self.calculate_source_accuracy(self.G, self.C)
            print(f"Epoch {epoch+1}/{self.epochs} - Source Domain Accuracy: {source_accuracy:.2f}%")

            # Save model weights after each epoch
            self.save_weights(epoch+1)

            try:
                # MMD with Norm (A_st_norm)
                f_of_X = torch.from_numpy(fea_s_for_LDA).float().to(device)
                f_of_Y = torch.from_numpy(fea_for_LDA).float().to(device)
                loss_mmd = self.linear_mmd(f_of_X, f_of_Y)
                A_st = loss_mmd.cpu().detach().numpy()
                A_st_max = max(abs(A_st_max), abs(A_st))
                A_st_min = min(abs(A_st_min), abs(A_st))
                A_st_norm = abs(A_st-A_st_min)/(A_st_max-A_st_min+1e-6)

                # J_w calculation
                self.class_num = 2  # Binary classification has 2 classes

                # J_w_s with Norm
                n_dim = 1
                clusters1 = [0, 1]
                label_s_for_LDA = np.array(label_s_for_LDA).reshape(-1)

                Sw1 = np.zeros((fea_s_for_LDA.shape[1], fea_s_for_LDA.shape[1]))
                for i in clusters1:
                    class_indices = (label_s_for_LDA == i)
                    if np.sum(class_indices) > 1:  # Ensure we have at least 2 samples for this class
                        datai1 = fea_s_for_LDA[class_indices]
                        datai1 = datai1 - datai1.mean(0)
                        Swi1 = np.asmatrix(datai1).T * np.asmatrix(datai1)
                        Sw1 += Swi1

                # Between-class scatter matrix
                SB1 = np.zeros((fea_s_for_LDA.shape[1], fea_s_for_LDA.shape[1]))
                u1 = fea_s_for_LDA.mean(0)  # Average of all samples
                for i in clusters1:
                    class_indices = (label_s_for_LDA == i)
                    if np.sum(class_indices) > 0:  # Ensure we have samples for this class
                        Ni1 = fea_s_for_LDA[class_indices].shape[0]
                        ui1 = fea_s_for_LDA[class_indices].mean(0)  # Average of a category
                        SBi1 = Ni1 * np.asmatrix(ui1 - u1).T * np.asmatrix(ui1 - u1)
                        SB1 += SBi1

                # Add small regularization to ensure invertibility
                S1 = np.linalg.inv(Sw1 + (1e-6 * np.eye(Sw1.shape[0]))) * SB1
                eigVals1, eigVects1 = np.linalg.eig(S1)  # Find eigenvalues, eigenvectors
                eigValInd1 = np.argsort(eigVals1)
                eigValInd1 = eigValInd1[:(-n_dim-1):-1]
                J_max1 = 0
                for i in range(min(n_dim, len(eigValInd1))):
                    J_max1 = J_max1 + np.real(eigVals1[eigValInd1[i]])  # Ensure we use real part

                J_w_s = J_max1/self.class_num
                max_J_w = max(max_J_w, J_w_s)
                min_J_w = min(min_J_w, J_w_s)

                # J_w_t calculation
                n_dim = 1
                label_for_LDA_flat = label_for_LDA.reshape(-1)
                clusters = np.unique(label_for_LDA_flat)

                Sw = np.zeros((fea_for_LDA.shape[1], fea_for_LDA.shape[1]))
                for i in clusters:
                    class_indices = (label_for_LDA_flat == i)
                    if np.sum(class_indices) > 1:  # Ensure we have at least 2 samples for this class
                        datai = fea_for_LDA[class_indices]
                        datai = datai - datai.mean(0)
                        Swi = np.asmatrix(datai).T * np.asmatrix(datai)
                        Sw += Swi

                # Between-class scatter matrix
                SB = np.zeros((fea_for_LDA.shape[1], fea_for_LDA.shape[1]))
                u = fea_for_LDA.mean(0)  # Average of all samples
                for i in clusters:
                    class_indices = (label_for_LDA_flat == i)
                    if np.sum(class_indices) > 0:  # Ensure we have samples for this class
                        Ni = fea_for_LDA[class_indices].shape[0]
                        ui = fea_for_LDA[class_indices].mean(0)  # Average of a category
                        SBi = Ni * np.asmatrix(ui - u).T * np.asmatrix(ui - u)
                        SB += SBi

                # Add small regularization to ensure invertibility
                S = np.linalg.inv(Sw + (1e-6 * np.eye(Sw.shape[0]))) * SB
                eigVals, eigVects = np.linalg.eig(S)  # Find eigenvalues, eigenvectors
                eigValInd = np.argsort(eigVals)
                eigValInd = eigValInd[:(-n_dim-1):-1]
                J_max = 0
                for i in range(min(n_dim, len(eigValInd))):
                    J_max = J_max + np.real(eigVals[eigValInd[i]])  # Ensure we use real part

                J_w_t = J_max/self.class_num
                min_J_w = min(min_J_w, J_w_t)
                max_J_w = max(max_J_w, J_w_t)

                # J_w_s_norm
                J_w = min(J_w_s, J_w_t)
                J_w_norm = (J_w - min_J_w)/(max_J_w-min_J_w+1e-6)

                print(f"Epoch {epoch+1} - A_st_norm: {A_st_norm:.4f}, J_w_norm: {J_w_norm:.4f}")

            except Exception as e:
                print(f"Warning: Could not calculate A_st_norm and J_w_norm for epoch {epoch+1}. Error: {e}")
                # Keep previous values
                print(f"Using previous values: A_st_norm: {A_st_norm:.4f}, J_w_norm: {J_w_norm:.4f}")

    def calculate_source_accuracy(self, model_G, model_C):
        """
        Calculate accuracy on the source domain

        Args:
            model_G: Feature extractor model
            model_C: Classifier model
            dataloader: DataLoader containing source and target data
            device: Device to run inference on (cpu or cuda)

        Returns:
            float: Accuracy on source domain
        """
        model_G.eval()
        model_C.eval()

        correct = 0
        total = 0

        with torch.no_grad():
            for batch in self.dataloader:
                source_features = batch['source_features'].to(device)
                source_labels = batch['source_label'].to(device)

                # Extract features and classify
                features = model_G(source_features)
                outputs = model_C(features)

                # Convert sigmoid outputs to binary predictions (threshold at 0.5)
                predicted = (outputs >= 0.5).float()

                # Update statistics
                total += source_labels.size(0)
                correct += (predicted == source_labels).sum().item()

        accuracy = 100 * correct / total

        # Reset models back to training mode
        model_G.train()
        model_C.train()

        return accuracy

In [None]:
# Build the trainer from the configured CSV paths and batch size, then train.
# `batch_size` is the config-cell constant (16) rather than a repeated literal,
# so changing the configuration actually takes effect here.
model = DomainAdaptation(source_path, target_path, batch_size=batch_size)
model.train()

Loaded 1024 samples


Epoch 1/20: 100%|██████████| 64/64 [00:10<00:00,  5.94it/s]


Epoch 1/20 - Source Domain Accuracy: 80.66%
Model weights saved for epoch 1
Epoch 1 - A_st_norm: 0.3043, J_w_norm: 0.8697


Epoch 2/20: 100%|██████████| 64/64 [00:09<00:00,  6.70it/s]


Epoch 2/20 - Source Domain Accuracy: 75.29%
Model weights saved for epoch 2
Epoch 2 - A_st_norm: 1.0000, J_w_norm: 0.1244


Epoch 3/20: 100%|██████████| 64/64 [00:10<00:00,  6.33it/s]


Epoch 3/20 - Source Domain Accuracy: 80.76%
Model weights saved for epoch 3
Epoch 3 - A_st_norm: 1.0000, J_w_norm: 0.1327


Epoch 4/20: 100%|██████████| 64/64 [00:10<00:00,  5.86it/s]


Epoch 4/20 - Source Domain Accuracy: 82.71%
Model weights saved for epoch 4
Epoch 4 - A_st_norm: 1.0000, J_w_norm: 0.1043


Epoch 5/20: 100%|██████████| 64/64 [00:10<00:00,  5.82it/s]


Epoch 5/20 - Source Domain Accuracy: 87.40%
Model weights saved for epoch 5
Epoch 5 - A_st_norm: 0.7715, J_w_norm: 0.6392


Epoch 6/20: 100%|██████████| 64/64 [00:10<00:00,  5.86it/s]


Epoch 6/20 - Source Domain Accuracy: 85.35%
Model weights saved for epoch 6
Epoch 6 - A_st_norm: 0.3971, J_w_norm: 0.0462


Epoch 7/20: 100%|██████████| 64/64 [00:09<00:00,  6.89it/s]


Epoch 7/20 - Source Domain Accuracy: 91.11%
Model weights saved for epoch 7
Epoch 7 - A_st_norm: 0.2861, J_w_norm: 0.0412


Epoch 8/20: 100%|██████████| 64/64 [00:10<00:00,  6.26it/s]


Epoch 8/20 - Source Domain Accuracy: 92.68%
Model weights saved for epoch 8
Epoch 8 - A_st_norm: 0.6625, J_w_norm: 0.0380


Epoch 9/20: 100%|██████████| 64/64 [00:11<00:00,  5.76it/s]


Epoch 9/20 - Source Domain Accuracy: 96.29%
Model weights saved for epoch 9
Epoch 9 - A_st_norm: 0.2286, J_w_norm: 0.1223


Epoch 10/20: 100%|██████████| 64/64 [00:11<00:00,  5.76it/s]


Epoch 10/20 - Source Domain Accuracy: 87.89%
Model weights saved for epoch 10
Epoch 10 - A_st_norm: 0.1332, J_w_norm: 0.1768


Epoch 11/20: 100%|██████████| 64/64 [00:11<00:00,  5.75it/s]


Epoch 11/20 - Source Domain Accuracy: 96.68%
Model weights saved for epoch 11
Epoch 11 - A_st_norm: 0.1132, J_w_norm: 0.0805


Epoch 12/20: 100%|██████████| 64/64 [00:09<00:00,  6.84it/s]


Epoch 12/20 - Source Domain Accuracy: 95.90%
Model weights saved for epoch 12
Epoch 12 - A_st_norm: 0.0358, J_w_norm: 0.0339


Epoch 13/20: 100%|██████████| 64/64 [00:10<00:00,  6.01it/s]


Epoch 13/20 - Source Domain Accuracy: 92.97%
Model weights saved for epoch 13
Epoch 13 - A_st_norm: 0.1419, J_w_norm: 0.0520


Epoch 14/20: 100%|██████████| 64/64 [00:11<00:00,  5.75it/s]


Epoch 14/20 - Source Domain Accuracy: 98.05%
Model weights saved for epoch 14
Epoch 14 - A_st_norm: 0.1735, J_w_norm: 0.0194


Epoch 15/20: 100%|██████████| 64/64 [00:11<00:00,  5.80it/s]


Epoch 15/20 - Source Domain Accuracy: 96.78%
Model weights saved for epoch 15
Epoch 15 - A_st_norm: 0.0480, J_w_norm: 0.0352


Epoch 16/20: 100%|██████████| 64/64 [00:11<00:00,  5.73it/s]


Epoch 16/20 - Source Domain Accuracy: 91.11%
Model weights saved for epoch 16
Epoch 16 - A_st_norm: 0.0324, J_w_norm: 0.0233


Epoch 17/20: 100%|██████████| 64/64 [00:09<00:00,  6.85it/s]


Epoch 17/20 - Source Domain Accuracy: 98.14%
Model weights saved for epoch 17
Epoch 17 - A_st_norm: 0.0770, J_w_norm: 0.0526


Epoch 18/20: 100%|██████████| 64/64 [00:10<00:00,  6.31it/s]


Epoch 18/20 - Source Domain Accuracy: 89.75%
Model weights saved for epoch 18
Epoch 18 - A_st_norm: 0.0222, J_w_norm: 0.0197


Epoch 19/20: 100%|██████████| 64/64 [00:11<00:00,  5.76it/s]


Epoch 19/20 - Source Domain Accuracy: 96.19%
Model weights saved for epoch 19
Epoch 19 - A_st_norm: 0.0356, J_w_norm: 0.0145


Epoch 20/20: 100%|██████████| 64/64 [00:10<00:00,  5.82it/s]


Epoch 20/20 - Source Domain Accuracy: 97.75%
Model weights saved for epoch 20
Epoch 20 - A_st_norm: 0.0387, J_w_norm: 0.0228


In [None]:
def classify_csv(model, checkpoint_epoch, input_csv_path, output_csv_path, batch_size=4096):
    """Classify the rows of a feature CSV with a saved checkpoint.

    Loads the weights stored for ``checkpoint_epoch`` into ``model``, runs
    ``model.G`` followed by ``model.C`` on the ``dim_0`` ... ``dim_1023``
    columns of the input file, thresholds the outputs at 0.5 and writes the
    remaining (non-feature) columns plus a ``prediction`` column to
    ``output_csv_path``.

    Args:
        model (DomainAdaptation): Trainer whose ``G`` and ``C`` are used.
        checkpoint_epoch (int): Epoch number of the checkpoint to load.
        input_csv_path (str): CSV containing the feature columns.
        output_csv_path (str): Destination of the CSV with predictions.
        batch_size (int): Number of rows pushed through the networks at once,
            so large files do not have to fit on the device in one piece.

    Raises:
        FileNotFoundError: If the requested checkpoint could not be loaded.
            Previously the function carried on and predicted with whatever
            weights ``model`` happened to hold.
        ValueError: If the input CSV lacks any of the feature columns.
    """
    # load_checkpoint returns the epoch on success and 0 when the file is
    # missing (a missing checkpoint for epoch 0 cannot be told apart).
    loaded_epoch = model.load_checkpoint(checkpoint_epoch)
    if loaded_epoch != checkpoint_epoch:
        raise FileNotFoundError(f"No checkpoint available for epoch {checkpoint_epoch}")

    # Set models to evaluation mode
    model.G.eval()
    model.C.eval()

    df = pd.read_csv(input_csv_path)
    feature_cols = [f"dim_{i}" for i in range(1024)]

    # Check for missing columns
    missing_cols = [col for col in feature_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Input CSV is missing the following columns: {missing_cols}")

    # Features stay on the CPU; each chunk is moved to `device` on its own
    features = torch.tensor(df[feature_cols].values.astype(np.float32))

    chunk_predictions = []
    with torch.no_grad():
        for start in range(0, features.shape[0], batch_size):
            chunk = features[start:start + batch_size].to(device)
            outputs = model.C(model.G(chunk))
            # Convert outputs to binary predictions using a 0.5 threshold
            chunk_predictions.append((outputs >= 0.5).float().cpu())

    if chunk_predictions:
        predictions = torch.cat(chunk_predictions).numpy()
    else:
        predictions = np.empty(0, dtype=np.float32)  # input had no rows

    # Append predictions, drop the raw feature columns and save
    df["prediction"] = predictions
    df = df.drop(columns=feature_cols)
    df.to_csv(output_csv_path, index=False)

    print(f"Predictions saved in: {output_csv_path}")


# Once training has finished and the weights are on disk, label new data
# using the checkpoint written after epoch 20.
classify_csv(
    model,
    checkpoint_epoch=20,
    input_csv_path="/content/dev_merge.csv",
    output_csv_path="predictions.csv",
)


Loaded checkpoint from epoch 20
Predictions saved in: predictions.csv
