# Load Files and everything (only for colab)

In [1]:
!pip install pytorch_metric_learning

Collecting pytorch_metric_learning
  Downloading pytorch_metric_learning-2.6.0-py3-none-any.whl.metadata (17 kB)
Downloading pytorch_metric_learning-2.6.0-py3-none-any.whl (119 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.3/119.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch_metric_learning
Successfully installed pytorch_metric_learning-2.6.0


In [2]:
!git clone --single-branch --branch code-fixes https://github.com/ayhamo/SRP_Domain_Adaptation.git

Cloning into 'SRP_Domain_Adaptation'...
remote: Enumerating objects: 360, done.[K
remote: Counting objects: 100% (106/106), done.[K
remote: Compressing objects: 100% (75/75), done.[K
remote: Total 360 (delta 50), reused 80 (delta 30), pack-reused 254 (from 1)[K
Receiving objects: 100% (360/360), 464.33 MiB | 12.95 MiB/s, done.
Resolving deltas: 100% (65/65), done.
Updating files: 100% (218/218), done.


In [3]:
import os
os.chdir('/content/SRP_Domain_Adaptation/')

# Original RAINCOAT.py file

Given By the authors, had to make alot of tinkering to get it to work

## code

In [14]:
%%writefile /content/SRP_Domain_Adaptation/algorithms/RAINCOAT.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import losses

from models.loss import SinkhornDistance

class Algorithm(torch.nn.Module):
    """
    A subclass of Algorithm implements a domain adaptation algorithm.
    Subclasses should implement the update() method.
    """

    def __init__(self, configs):
        super(Algorithm, self).__init__()
        self.configs = configs
        self.cross_entropy = nn.CrossEntropyLoss()

    def update(self, *args, **kwargs):
        raise NotImplementedError

class classifier(nn.Module):
    def __init__(self, configs):
        super(classifier, self).__init__()
        model_output_dim = configs.out_dim
        self.logits = nn.Linear(model_output_dim, configs.num_classes, bias=False)
        self.tmp= 0.1

    def forward(self, x):
        predictions = self.logits(x)/self.tmp
        return predictions

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, fl=128):
        super(SpectralConv1d, self).__init__()

        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  #Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))
        self.pi = torch.acos(torch.zeros(1)).item() * 2

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x = torch.cos(x)
        x_ft = torch.fft.rfft(x,norm='ortho')
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1,  device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)
        r = out_ft[:, :, :self.modes1].abs()
        p = out_ft[:, :, :self.modes1].angle()
        return torch.concat([r,p],-1), out_ft


class CNN(nn.Module):
    def __init__(self, configs):
        super(CNN, self).__init__()
        self.width = configs.input_channels
        self.channel = configs.input_channels
        self.fl =   configs.sequence_len
        self.fc0 = nn.Linear(self.channel, self.width) # input channel is 2: (a(x), x)
        self.conv_block1 = nn.Sequential(
            nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
                      stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(configs.dropout)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv1d(configs.mid_channels, configs.mid_channels , kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv1d(configs.mid_channels , configs.final_out_channels, kernel_size=8, stride=1, bias=False,
                      padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)


    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block3(x)
        x = self.adaptive_pool(x)
        x_flat = x.reshape(x.shape[0], -1)
        return x_flat

class tf_encoder(nn.Module):
    def __init__(self, configs):
        super(tf_encoder, self).__init__()
        self.modes1 = configs.fourier_modes   # Number of low-frequency modes to keep
        self.width = configs.input_channels
        self.length =  configs.sequence_len
        self.freq_feature = SpectralConv1d(self.width, self.width, self.modes1,self.length)  # Frequency Feature Encoder
        self.bn_freq = nn.BatchNorm1d(configs.fourier_modes*2)   # It doubles because frequency features contain both amplitude and phase
        self.cnn = CNN(configs).to('cuda')  # Time Feature Encoder
        self.avg = nn.Conv1d(self.width, 1, kernel_size=3 ,
                  stride=configs.stride, bias=False, padding=(3 // 2))


    def forward(self, x):
        ef, out_ft = self.freq_feature(x)
        ef = F.relu(self.bn_freq(self.avg(ef).squeeze()))
        et = self.cnn(x)
        f = torch.concat([ef,et],-1)
        return F.normalize(f), out_ft

class tf_decoder(nn.Module):
    def __init__(self, configs):
        super(tf_decoder, self).__init__()
        self.input_channels, self.sequence_len = configs.input_channels, configs.sequence_len
        self.bn1 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.bn2 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.convT = torch.nn.ConvTranspose1d(configs.final_out_channels, self.sequence_len, self.input_channels, stride=1)
        self.modes = configs.fourier_modes

    def forward(self, f, out_ft):
        x_low = self.bn1(torch.fft.irfft(out_ft, n=128))   # reconstruct  time series by using low frequency frequency features
        et = f[:,self.modes*2:]
        x_high = F.relu(self.bn2(self.convT(et.unsqueeze(2)).permute(0,2,1))) # reconstruct time series by using time features for high frequency patterns.
        return x_low + x_high


class RAINCOAT(Algorithm):
    def __init__(self, configs, hparams, device):
        super(RAINCOAT, self).__init__(configs)
        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)

        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters()) + \
                list(self.decoder.parameters())+\
                list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.coptimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())+list(self.decoder.parameters()),
            lr=1*hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )

        self.hparams = hparams
        self.recons = nn.L1Loss(reduction='sum').to(device)
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum')

    def update(self, src_x, src_y, trg_x):

        self.optimizer.zero_grad()
        # Encode both source and target features via our time-frequency feature encoder
        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)
        # Decode extracted features to time series
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)
        # Compute reconstruction loss
        recons = 1e-4 * (self.recons(src_recon, src_x) + self.recons(trg_recon, trg_x))
        recons.backward(retain_graph=True)
        # Compute alignment loss
        dr, _, _ = self.sink(src_feat, trg_feat)
        sink_loss = dr
        sink_loss.backward(retain_graph=True)
        # Compute classification loss
        src_pred = self.classifier(src_feat)
        loss_cls = self.cross_entropy(src_pred, src_y)
        loss_cls.backward(retain_graph=True)
        self.optimizer.step()
        return {'Src_cls_loss': loss_cls.item(),'Sink': sink_loss.item()}

    def correct(self,src_x, src_y, trg_x):
        self.coptimizer.zero_grad()
        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)
        recons = 1e-4 * (self.recons(trg_recon, trg_x) + self.recons(src_recon, src_x))
        recons.backward()
        self.coptimizer.step()
        return {'recon': recons.item()}

Overwriting /content/SRP_Domain_Adaptation/algorithms/RAINCOAT.py


# Fixed RAINCOAT.py file by us



1.   CNN time features had missing layers added
2.   Update (Aligment) step was using unweighted and wrong loss, was updated to include weights (a,b,c) as per paper and use correct summed losses
3. Added validation for acc/f1 evulation



## code

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import losses

from models.loss import SinkhornDistance

class Algorithm(torch.nn.Module):

    def __init__(self, configs):
        super(Algorithm, self).__init__()
        self.configs = configs
        self.cross_entropy = nn.CrossEntropyLoss()

    def update(self, *args, **kwargs):
        raise NotImplementedError

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, fl=128):
        super(SpectralConv1d, self).__init__()

        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  #Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))
        self.pi = torch.acos(torch.zeros(1)).item() * 2

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x = torch.cos(x)
        x_ft = torch.fft.rfft(x,norm='ortho')
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1,  device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)
        r = out_ft[:, :, :self.modes1].abs()
        p = out_ft[:, :, :self.modes1].angle()
        return torch.concat([r,p],-1), out_ft


class CNN(nn.Module):
    def __init__(self, configs):
        super(CNN, self).__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
                      stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
            nn.Dropout(configs.dropout)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv1d(configs.mid_channels, configs.mid_channels , kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1)
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv1d(configs.mid_channels , configs.final_out_channels, kernel_size=8, stride=1, bias=False,
                      padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )
        self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)


    def forward(self, x):
        x = self.conv_block1(x)
        # added from models.py (missing layers)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.adaptive_pool(x)
        x_flat = x.reshape(x.shape[0], -1)
        return x_flat

class tf_encoder(nn.Module):
    def __init__(self, configs):
        super(tf_encoder, self).__init__()
        self.modes1 = configs.fourier_modes   # Number of low-frequency modes to keep
        self.width = configs.input_channels
        self.length =  configs.sequence_len

        self.freq_feature = SpectralConv1d(self.width, self.width, self.modes1,self.length)  # Frequency Feature Encoder
        self.bn_freq = nn.BatchNorm1d(configs.fourier_modes*2)   # It doubles because frequency features contain both amplitude and phase
        self.cnn = CNN(configs).to('cuda')  # Time Feature Encoder
        self.avg = nn.Conv1d(self.width, 1, kernel_size=3 ,
                  stride=configs.stride, bias=False, padding=(3 // 2))


    def forward(self, x):
        ef, out_ft = self.freq_feature(x)
        ef = F.relu(self.bn_freq(self.avg(ef).squeeze()))
        et = self.cnn(x)
        f = torch.concat([ef,et],-1)
        return F.normalize(f), out_ft

class tf_decoder(nn.Module):
    def __init__(self, configs):
        super(tf_decoder, self).__init__()
        self.input_channels, self.sequence_len = configs.input_channels, configs.sequence_len
        self.bn1 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.bn2 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.convT = torch.nn.ConvTranspose1d(configs.final_out_channels, self.sequence_len, self.input_channels, stride=1)
        self.modes = configs.fourier_modes

    def forward(self, f, out_ft):
        x_low = self.bn1(torch.fft.irfft(out_ft, n=128))   # reconstruct  time series by using low frequency frequency features
        et = f[:,self.modes*2:]
        x_high = F.relu(self.bn2(self.convT(et.unsqueeze(2)).permute(0,2,1))) # reconstruct time series by using time features for high frequency patterns.
        return x_low + x_high

class classifier(nn.Module):
    def __init__(self, configs):
        super(classifier, self).__init__()
        model_output_dim = configs.out_dim
        self.logits = nn.Linear(model_output_dim, configs.num_classes, bias=False)
        self.tmp= 0.1

    def forward(self, x):
        predictions = self.logits(x)/self.tmp
        return predictions

class RAINCOAT(Algorithm):
    def __init__(self, configs, hparams, device):
        super(RAINCOAT, self).__init__(configs)
        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)

        self.optimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters()) + \
                list(self.decoder.parameters())+\
                list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        self.coptimizer = torch.optim.Adam(
            list(self.feature_extractor.parameters())+list(self.decoder.parameters()),
            lr=1*hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )

        self.hparams = hparams
        self.recons = nn.L1Loss(reduction='sum').to(device)
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum')

    def align(self, src_x, src_y, trg_x):
        self.optimizer.zero_grad()

        # Encode both source and target features via our time-frequency feature encoder
        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)
        # Decode extracted features to time series
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)

        # Compute reconstruction loss (added the 0.2 weight here as per paper)
        recons = 1e-4 * (self.recons(src_recon, src_x) + self.recons(trg_recon, trg_x))
        recons.backward(retain_graph=True)

        # Compute alignment loss
        dr, _, _ = self.sink(src_feat, trg_feat)
        sink_loss = dr
        sink_loss.backward(retain_graph=True)

        # Compute classification loss
        src_pred = self.classifier(src_feat)
        loss_cls = self.cross_entropy(src_pred, src_y)
        loss_cls.backward(retain_graph=True)

        # Compute weights
        a, b, c = 1, 1, 0.2
        total = a + b + c
        lambda1 = a / total
        lambda2 = b / total
        lambda3 = c / total

        # Compute total loss with weights
        total_loss = lambda1 * recons + lambda2 * sink_loss + lambda3 * loss_cls
        self.optimizer.step()

        return {
            'Total_loss': total_loss.item(),
            'Reconstruction_loss': recons.item(),
            'Alignment_loss': sink_loss.item(),
            'Classification_loss': loss_cls.item()
            }


    def correct(self,src_x, src_y, trg_x):
        self.coptimizer.zero_grad()

        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)

        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)

        recons = 1e-4 * (self.recons(trg_recon, trg_x) + self.recons(src_recon, src_x))
        recons.backward()
        self.coptimizer.step()

        return {'Correct_reconstruction_loss': recons.item()}


Overwriting /content/SRP_Domain_Adaptation/algorithms/RAINCOAT.py


# New RAINCOAT - Our Changes

v1:
- Fractional Fourier Transform ( in SpectralConv1d and new frft method )
    - FrFT enhances time series prediction by offering a parametric transformation that improves model performance, at no extra cost of computation
    - introduced in the paper "Fractional Fourier Transform in Time Series Prediction"
        - http://yoksis.bilkent.edu.tr/pdf/files/16189.pdf
- Updated Decoder to Handle new Transform
- Changing decoder to use magnitude and phase, capture different aspects of the signal and improve perfomance

v2:
- Improved upon CNN's
    - Changed Relu to Mish(Mish is smooth, self-regularized non-monotonic function [This means that the function’s behavior varies depending on the input values], and has properties that can help with gradient flow in deep networks), it outprefmored ReLU and leaky ReLU
        - Introduced in 2019 by paper: https://arxiv.org/abs/1908.08681
    - Changed some Maxpoolings to AdaptiveAvgPooling
    - Channels were increased in 2th and 3rd layers, reduced then to original by 4th conv block 
- Introduced AdamW (with weight) with OneCycleLR Scheduler for learning rate
    - The learning rate & weight starts at a low value, increases to a maximum value, and then decreases to a value much lower than the initial learning rate
    - Comes with Momentum Cycling, the momentum is cycled inversely to the learning rate, ie When the learning rate is high, the momentum is low, and vice versa
    - reduces training time by using large learning rate, and that it act as a form of regularization, reducing the need for other regularization methods, and also large learning rates regularize training by increasing gradient noise, which improves generalization.
        - Introduced in the paper “Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates".
        - https://arxiv.org/abs/1708.07120

v3:
- Introduced Parallel Computing Optimization (code now runs 50% faster from 10 mins to 4.50 mins for 5 runs)
- Introduced CPU computing along with cuda (prevoiusly cpu would crash the program)
- Optimized read/write abuse that the old code was doing, and replaced it with variables
- Grid search was setup to find best hyperparamters for the changes in v2 with parallel optimization
    - Hyperparamters include Epoch, Weight decay, Batch size, scheduler & Co-scheduler steps per epoch paramter, Fractional Fourier (a) paramter, diffrenet activision functions and layers in CNN

v4:
- Made a grid serach for loss weights in order to get best combination

Original values over 5 runs + (Avg original results):

Avg Accuracy: 72.37737009913369, Accuracy STD: 1.089658045279597


Avg F1: 0.5378341391111486, F1 STD: 0.02216485961311463

Fixed version values over 5 runs + (Avg Fixed results): (due to validation, seems thier results is overfitting)


Avg Accuracy: 57.83547537641728, Accuracy STD: 4.115469947796791, 

Avg F1: 0.3237270780575732, F1 STD: 0.021784388570534444

v2 and v3 updates explained above (Avg Final):

Avg Accuracy: 63.11333738107413, Accuracy STD: 2.699980524318425, 

Avg F1: 0.47183228003868677, F1 STD: 0.01960662122499978

TODO:

Add training and validation curves

Add inception time instead of CNN's


## code

In [None]:
%%writefile /content/SRP_Domain_Adaptation/algorithms/RAINCOAT.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_metric_learning import losses
import torch.fft

from .loss import SinkhornDistance

class Algorithm(torch.nn.Module):

    def __init__(self, configs):
        super(Algorithm, self).__init__()
        self.configs = configs
        self.cross_entropy = nn.CrossEntropyLoss()

    def update(self, *args, **kwargs):
        raise NotImplementedError

# Fractional Fourier Transform, better perfomance
# http://yoksis.bilkent.edu.tr/pdf/files/16189.pdf

class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, fraction_order):
        super(SpectralConv1d, self).__init__()
        """
        1D Fractional Fourier layer. It does FrFT, linear transform, and Inverse FrFT.    
        """
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # Number of Fourier modes to multiply, at most floor(N/2) + 1
        self.fraction_order = fraction_order  # Fractional order for FrFT

        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))
        self.pi = torch.acos(torch.zeros(1)).item() * 2

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def frft(self, x, a):
        """
        Perform the Fractional Fourier Transform (FrFT) on the input tensor x with order a.
        """
        N = x.shape[-1]
        k = torch.arange(0, N, device=x.device)
        exp_term = torch.exp(-1j * self.pi * a * k**2 / N)
        x_ft = torch.fft.fft(x)
        x_frft = torch.fft.ifft(x_ft * exp_term)
        return x_frft

    def forward(self, x):
        batchsize = x.shape[0]
        # Compute Fractional Fourier coefficients up to factor of e^(- something constant)
        x = torch.cos(x)
        x_frft = self.frft(x, self.fraction_order)
        out_frft = torch.zeros(batchsize, self.out_channels, x_frft.size(-1), device=x.device, dtype=torch.cfloat)
        out_frft[:, :, :self.modes1] = self.compl_mul1d(x_frft[:, :, :self.modes1], self.weights1)
        r = out_frft[:, :, :self.modes1].abs()
        p = out_frft[:, :, :self.modes1].angle()
        return torch.concat([r, p], -1), out_frft


class CNN(nn.Module):
    def __init__(self, configs):
        super(CNN, self).__init__()

        # Mish: A Self Regularized Non-Monotonic Activation Function
        # https://arxiv.org/abs/1908.08681

        self.conv_block1 = nn.Sequential(
            nn.Conv1d(configs.input_channels, configs.mid_channels, kernel_size=configs.kernel_size,
                      stride=configs.stride, bias=False, padding=(configs.kernel_size // 2)),
            nn.BatchNorm1d(configs.mid_channels),
            nn.Mish(),
            nn.AdaptiveAvgPool1d(configs.features_len),
            nn.Dropout(configs.dropout)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv1d(configs.mid_channels, configs.mid_channels*2, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels*2),
            nn.Mish(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        self.conv_block3 = nn.Sequential(
            nn.Conv1d(configs.mid_channels*2, configs.mid_channels*2, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.mid_channels*2),
            nn.Mish(),
            nn.AdaptiveAvgPool1d(configs.features_len),
        )

        # New convolutional block that reduces channels back to final_out_channels
        self.conv_block4 = nn.Sequential(
            nn.Conv1d(configs.mid_channels*2, configs.final_out_channels, kernel_size=8, stride=1, bias=False, padding=4),
            nn.BatchNorm1d(configs.final_out_channels),
            nn.Mish(),
            nn.MaxPool1d(kernel_size=2, stride=2, padding=1),
        )

        self.adaptive_pool = nn.AdaptiveAvgPool1d(configs.features_len)


    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        
        x = self.adaptive_pool(x)

        x_flat = x.reshape(x.shape[0], -1)
        return x_flat

class tf_encoder(nn.Module):
    def __init__(self, configs):
        super(tf_encoder, self).__init__()
        self.modes1 = configs.fourier_modes   # Number of low-frequency modes to keep
        self.width = configs.input_channels
        self.length =  configs.sequence_len
        self.fraction_order = configs.fraction_order

        self.freq_feature = SpectralConv1d(self.width, self.width, self.modes1, self.fraction_order)  # Frequency Feature Encoder
        self.bn_freq = nn.BatchNorm1d(configs.fourier_modes*2)   # It doubles because frequency features contain both amplitude and phase
        self.cnn = CNN(configs)  # Time Feature Encoder
        self.avg = nn.Conv1d(self.width, 1, kernel_size=3 ,
                  stride=configs.stride, bias=False, padding=(3 // 2))


    def forward(self, x):
        ef, out_ft = self.freq_feature(x)
        ef = F.relu(self.bn_freq(self.avg(ef).squeeze()))
        et = self.cnn(x)
        f = torch.concat([ef,et],-1)
        return F.normalize(f), out_ft

class tf_decoder(nn.Module):
    def __init__(self, configs):
        super(tf_decoder, self).__init__()
        self.input_channels, self.sequence_len = configs.input_channels, configs.sequence_len
        self.bn1 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.bn2 = nn.BatchNorm1d(self.input_channels,self.sequence_len)
        self.convT = torch.nn.ConvTranspose1d(configs.final_out_channels, self.sequence_len, self.input_channels, stride=1)
        self.modes = configs.fourier_modes

        self.channel_proj = nn.Conv1d(2 * self.input_channels, self.input_channels, kernel_size=1)

        self.fraction_order = configs.fraction_order
        self.pi = torch.acos(torch.zeros(1)).item() * 2

    def inverse_frft(self, x, a):
        """
        Perform the inverse Fractional Fourier Transform (FrFT) on the input tensor x with order a.
        """
        N = x.shape[-1]
        k = torch.arange(0, N, device=x.device)
        exp_term = torch.exp(1j * self.pi * a * k**2 / N)  # Note the positive sign for inverse
        x_ft = torch.fft.fft(x)
        x_inv_frft = torch.fft.ifft(x_ft * exp_term)
        return x_inv_frft
        
    def forward(self, f, out_ft):
        # Reconstruct time series by using low frequency features from FrFT
        x_low_complex = self.inverse_frft(out_ft, self.fraction_order)

        amplitude = x_low_complex.abs()
        phase = x_low_complex.angle()

        x_low = torch.cat([amplitude, phase], dim=1)
        x_low = self.channel_proj(x_low)
        x_low = self.bn1(x_low)

        et = f[:, self.modes * 2:]

        # Reconstruct time series by using time features for high frequency patterns
        x_high = F.relu(self.bn2(self.convT(et.unsqueeze(2)).permute(0, 2, 1)))

        return x_low + x_high

class classifier(nn.Module):
    def __init__(self, configs):
        super(classifier, self).__init__()
        model_output_dim = configs.out_dim
        self.logits = nn.Linear(model_output_dim, configs.num_classes, bias=False)
        self.tmp= 0.1

    def forward(self, x):
        predictions = self.logits(x)/self.tmp
        return predictions

class RAINCOAT(Algorithm):
    def __init__(self, configs, hparams, device):
        super(RAINCOAT, self).__init__(configs)
        self.feature_extractor = tf_encoder(configs).to(device)
        self.decoder = tf_decoder(configs).to(device)
        self.classifier = classifier(configs).to(device)

        self.optimizer = torch.optim.AdamW(
            list(self.feature_extractor.parameters()) + \
                list(self.decoder.parameters())+\
                list(self.classifier.parameters()),
            lr=hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )

        # Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates".
        # https://arxiv.org/abs/1708.07120

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            cycle_momentum = True,
            max_lr= 1e-2,
            steps_per_epoch=hparams["scheduler_steps"],
            epochs=hparams["num_epochs"]
        )

        self.coptimizer = torch.optim.AdamW(
            list(self.feature_extractor.parameters())+list(self.decoder.parameters()),
            lr=1*hparams["learning_rate"],
            weight_decay=hparams["weight_decay"]
        )
        
        self.coscheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.coptimizer,
            cycle_momentum = True,
            max_lr= 1e-2,
            steps_per_epoch=hparams["coscheduler_steps"],
            epochs=hparams["corr_epochs"]
        )

        self.hparams = hparams
        self.recons = nn.L1Loss(reduction='sum').to(device)
        self.pi = torch.acos(torch.zeros(1)).item() * 2
        self.loss_func = losses.ContrastiveLoss(pos_margin=0.5)
        self.sink = SinkhornDistance(eps=1e-3, max_iter=1000, reduction='sum', device=device)

    def align(self, src_x, src_y, trg_x):
        self.optimizer.zero_grad()

        # Encode both source and target features via our time-frequency feature encoder
        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)
        # Decode extracted features to time series
        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)

        # Compute reconstruction loss (added the 0.2 weight here as per paper)
        recons = 1e-4 * (self.recons(src_recon, src_x) + self.recons(trg_recon, trg_x))
        recons.backward(retain_graph=True)

        # Compute alignment loss
        dr, _, _ = self.sink(src_feat, trg_feat)
        sink_loss = dr
        sink_loss.backward(retain_graph=True)

        # Compute classification loss
        src_pred = self.classifier(src_feat)
        loss_cls = self.cross_entropy(src_pred, src_y)
        loss_cls.backward(retain_graph=True)

        # Compute weights
        a, b, c = 1, 1, 0.2
        total = a + b + c
        lambda1 = a / total
        lambda2 = b / total
        lambda3 = c / total

        # Compute total loss with weights
        total_loss = lambda1 * recons + lambda2 * sink_loss + lambda3 * loss_cls
        
        self.optimizer.step()

        return {
            'Total_loss': total_loss.item(),
            'Reconstruction_loss': recons.item(),
            'Alignment_loss': sink_loss.item(),
            'Classification_loss': loss_cls.item()
            }


    def correct(self,src_x, src_y, trg_x):
        self.coptimizer.zero_grad()

        src_feat, out_s = self.feature_extractor(src_x)
        trg_feat, out_t = self.feature_extractor(trg_x)

        src_recon = self.decoder(src_feat, out_s)
        trg_recon = self.decoder(trg_feat, out_t)

        recons = 1e-4 * (self.recons(trg_recon, trg_x) + self.recons(src_recon, src_x))
        recons.backward()

        self.coptimizer.step()

        return {'Correct_reconstruction_loss': recons.item()}



Overwriting /content/SRP_Domain_Adaptation/algorithms/RAINCOAT.py


## Loss weight Grid Search

Best weights are: 1 for cross entropy, 1 for sinkhorn and 0.2 for Reconstruction

### Code and Results

In [None]:
import pandas as pd
import subprocess
from itertools import product

# Define the parameter grid
param_grid = {
    'a': [1, 0.9, 0.8, 0.7],
    'b': [0.5, 0.6, 0.7, 0.8, 0.9, 1],
    'c': [0.1, 0.2, 0.3, 0.4, 0.5]
}

# Create combinations of parameters
param_combinations = list(product(param_grid['a'], param_grid['b'], param_grid['c']))

# Initialize a list to store the results
results = []

for a, b, c in param_combinations:
    # Read the align method file
    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/algorithms/RAINCOAT.py', 'r') as file:
        data = file.readlines()

    # Modify the file with the current parameters
    for i, line in enumerate(data):
        if 'a=' in line:
            data[i] = f'        a={a}\n'
        if 'b=' in line:
            data[i] = f'        b={b}\n'
        if 'c=' in line:
            data[i] = f'        c={c}\n'
    
    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/algorithms/RAINCOAT.py', 'w') as file:
        file.writelines(data)

    # Run the script
    result = subprocess.run(['python3', '/kaggle/working/SRP_Domain_Adaptation/Raincoat/main.py', '--experiment_description', 'WISDM', '--dataset', 'WISDM', '--num_runs', '5', '--device', 'cuda'], capture_output=True, text=True)

    # Check if the script ran successfully
    if result.returncode != 0:
        print(f"Error running script for a={a}, b={b}, c={c}: {result.stderr}")
        continue

    # Read the results
    file_path = '/kaggle/working/SRP_Domain_Adaptation/experiments_logs/RAINCOAT ClosedSet/WISDM/average_correct.csv'
    df = pd.read_csv(file_path)
    last_row = df.tail(1).values[0]

    avg_accuracy = last_row[0].split(': ')[-1]
    accuracy_std = last_row[1].split(': ')[-1]
    avg_f1 = last_row[2].split(': ')[-1]
    f1_std = last_row[3].split(': ')[-1]

    # Append the results to the list
    results.append([a, b, c, avg_accuracy, accuracy_std, avg_f1, f1_std])

    print(f"a={a}, b={b}, c={c}: Avg Accuracy: {avg_accuracy}, Accuracy STD: {accuracy_std}, Avg F1: {avg_f1}, F1 STD: {f1_std}")

# Save the results to a CSV file
results_df = pd.DataFrame(results, columns=['a', 'b', 'c', 'avg_accuracy', 'accuracy_std', 'avg_f1', 'f1_std'])
results_df.to_csv('/kaggle/working/SRP_Domain_Adaptation/results.csv', index=False)



a=1, b=0.5, c=0.1: Avg Accuracy: 58.383162664285805, Accuracy STD: 4.344249233941357, Avg F1: 0.4270006275820081, F1 STD: 0.04109196239089278

a=1, b=0.5, c=0.2: Avg Accuracy: 58.25233111523833, Accuracy STD: 3.9667598886217985, Avg F1: 0.42206702524527734, F1 STD: 0.027860881973400045

a=1, b=0.5, c=0.3: Avg Accuracy: 57.830205511713984, Accuracy STD: 3.4621362894888112, Avg F1: 0.4138808260628923, F1 STD: 0.028217966542675266

a=1, b=0.5, c=0.4: Avg Accuracy: 56.99284308308046, Accuracy STD: 2.6399024710907257, Avg F1: 0.405147444054285, F1 STD: 0.024290056291302327

a=1, b=0.5, c=0.5: Avg Accuracy: 56.88666320278086, Accuracy STD: 3.203043351602124, Avg F1: 0.4037268640144133, F1 STD: 0.03049348607221128

a=1, b=0.6, c=0.1: Avg Accuracy: 58.80824482832546, Accuracy STD: 3.7210310999410603, Avg F1: 0.4347568830893806, F1 STD: 0.03177890122906314

a=1, b=0.6, c=0.2: Avg Accuracy: 58.47581537148238, Accuracy STD: 3.885758949380515, Avg F1: 0.4222068458662476, F1 STD: 0.0268267865915538

a=1, b=0.6, c=0.3: Avg Accuracy: 58.01217743063298, Accuracy STD: 2.461070061761912, Avg F1: 0.41903939316065497, F1 STD: 0.020261982674369926

a=1, b=0.6, c=0.4: Avg Accuracy: 57.05835569683897, Accuracy STD: 2.9926436623337143, Avg F1: 0.40825563051369046, F1 STD: 0.02502208000194788

a=1, b=0.6, c=0.5: Avg Accuracy: 56.49558940001914, Accuracy STD: 3.3013457954073795, Avg F1: 0.40349812791362094, F1 STD: 0.026637871332063846

a=1, b=0.7, c=0.1: Avg Accuracy: 58.90152979330345, Accuracy STD: 3.501267397439612, Avg F1: 0.4286373683363638, F1 STD: 0.03653971322340406

a=1, b=0.7, c=0.2: Avg Accuracy: 57.77648504713029, Accuracy STD: 3.4611467388663817, Avg F1: 0.41321538721445245, F1 STD: 0.020774353824323206

a=1, b=0.7, c=0.3: Avg Accuracy: 57.94463334508835, Accuracy STD: 3.5174775624298604, Avg F1: 0.41714414265846667, F1 STD: 0.030673160551712594

a=1, b=0.7, c=0.4: Avg Accuracy: 56.950773510158434, Accuracy STD: 3.1852197613277777, Avg F1: 0.404435528254066, F1 STD: 0.028875789590007434

a=1, b=0.7, c=0.5: Avg Accuracy: 56.65084575420171, Accuracy STD: 3.361942539719438, Avg F1: 0.409604234669083, F1 STD: 0.029433505334946423

a=1, b=0.8, c=0.1: Avg Accuracy: 58.28440935551056, Accuracy STD: 3.6271241017383193, Avg F1: 0.42336515210291764, F1 STD: 0.035173311060751014

a=1, b=0.8, c=0.2: Avg Accuracy: 58.49591106330085, Accuracy STD: 3.4544904958389484, Avg F1: 0.42185743561824796, F1 STD: 0.0275480584543292

a=1, b=0.8, c=0.3: Avg Accuracy: 58.05173564364717, Accuracy STD: 3.183995178087485, Avg F1: 0.4139348612207547, F1 STD: 0.023503396812714766

a=1, b=0.8, c=0.4: Avg Accuracy: 57.27915906780221, Accuracy STD: 2.657008324840166, Avg F1: 0.41076497178821325, F1 STD: 0.02023638639577716

a=1, b=0.8, c=0.5: Avg Accuracy: 56.430383454964684, Accuracy STD: 3.462411450691987, Avg F1: 0.40507696593306813, F1 STD: 0.029144303271249995

a=1, b=0.9, c=0.1: Avg Accuracy: 58.26942451070729, Accuracy STD: 3.4135156218742346, Avg F1: 0.42719145632730393, F1 STD: 0.032480711734862465

a=1, b=0.9, c=0.2: Avg Accuracy: 57.86070802077487, Accuracy STD: 3.5381461500051885, Avg F1: 0.41488747207333054, F1 STD: 0.022318490700175993

a=1, b=0.9, c=0.3: Avg Accuracy: 57.91580791640089, Accuracy STD: 3.675332822065747, Avg F1: 0.4157426677877726, F1 STD: 0.025009173613896023

a=1, b=0.9, c=0.4: Avg Accuracy: 57.17962366834149, Accuracy STD: 2.708355634390835, Avg F1: 0.40959842449344724, F1 STD: 0.02133063458357955

a=1, b=0.9, c=0.5: Avg Accuracy: 56.98141435434583, Accuracy STD: 3.1138438286802734, Avg F1: 0.4007095651938597, F1 STD: 0.017988537532722167

a=1, b=1, c=0.1: Avg Accuracy: 58.602669068638626, Accuracy STD: 3.0909292723951687, Avg F1: 0.4255514737311076, F1 STD: 0.03557826160735715

a=1, b=1, c=0.2: Avg Accuracy: 58.12893136115904, Accuracy STD: 4.131729209063483, Avg F1: 0.42081049266552484, F1 STD: 0.03287129432132202

a=1, b=1, c=0.3: Avg Accuracy: 58.08922638568242, Accuracy STD: 3.1863543600923947, Avg F1: 0.4203022525483607, F1 STD: 0.02573935006742992

a=1, b=1, c=0.4: Avg Accuracy: 56.91620333478001, Accuracy STD: 2.77264776080713, Avg F1: 0.40643589493755383, F1 STD: 0.0189429496363941

a=1, b=1, c=0.5: Avg Accuracy: 57.44074900126062, Accuracy STD: 3.014475449320494, Avg F1: 0.41350830667820376, F1 STD: 0.027389740188901897

a=0.9, b=0.5, c=0.1: Avg Accuracy: 58.55471447266733, Accuracy STD: 3.796812712051073, Avg F1: 0.4271512186424453, F1 STD: 0.029665566385341628

a=0.9, b=0.5, c=0.2: Avg Accuracy: 58.33809430337668, Accuracy STD: 3.374188137830791, Avg F1: 0.4254772056022834, F1 STD: 0.028924051981334718

a=0.9, b=0.5, c=0.3: Avg Accuracy: 56.94437339971944, Accuracy STD: 3.165970725870329, Avg F1: 0.4043148347169837, F1 STD: 0.022456026890381005

a=0.9, b=0.5, c=0.4: Avg Accuracy: 57.29817544748602, Accuracy STD: 2.523323413489675, Avg F1: 0.4064599226246096, F1 STD: 0.02720214555306415

a=0.9, b=0.5, c=0.5: Avg Accuracy: 56.61246119031587, Accuracy STD: 3.657429770236068, Avg F1: 0.4042200707455682, F1 STD: 0.027418890373965037

a=0.9, b=0.6, c=0.1: Avg Accuracy: 58.11474431330065, Accuracy STD: 3.9163496401226445, Avg F1: 0.4222164445301223, F1 STD: 0.041401646192502105

a=0.9, b=0.6, c=0.2: Avg Accuracy: 57.69499589923741, Accuracy STD: 3.1949851776378115, Avg F1: 0.4159625457117918, F1 STD: 0.022795435506056783

a=0.9, b=0.6, c=0.3: Avg Accuracy: 57.57069194128826, Accuracy STD: 2.7395597662679503, Avg F1: 0.4079660679071667, F1 STD: 0.026860056446020675

a=0.9, b=0.6, c=0.4: Avg Accuracy: 57.00179251554802, Accuracy STD: 3.088935245221947, Avg F1: 0.410563223548138, F1 STD: 0.02731928946046437

a=0.9, b=0.6, c=0.5: Avg Accuracy: 56.55163052634579, Accuracy STD: 3.205004593599423, Avg F1: 0.40304470562839717, F1 STD: 0.028493094245384876

a=0.9, b=0.7, c=0.1: Avg Accuracy: 58.09617083462575, Accuracy STD: 2.9401272783773855, Avg F1: 0.4189534290567397, F1 STD: 0.024022926533345798

a=0.9, b=0.7, c=0.2: Avg Accuracy: 58.11492018015629, Accuracy STD: 3.327398093992909, Avg F1: 0.4152932138590826, F1 STD: 0.02356062363873845

a=0.9, b=0.7, c=0.3: Avg Accuracy: 57.612343063447476, Accuracy STD: 2.7650989545540403, Avg F1: 0.4083880391703767, F1 STD: 0.019723259865790637

a=0.9, b=0.7, c=0.4: Avg Accuracy: 56.84919499880463, Accuracy STD: 3.056169977921685, Avg F1: 0.4065824331102264, F1 STD: 0.029902894498653822

a=0.9, b=0.7, c=0.5: Avg Accuracy: 56.877432463969754, Accuracy STD: 2.98187020416884, Avg F1: 0.41107921579610174, F1 STD: 0.025935572695270578

a=0.9, b=0.8, c=0.1: Avg Accuracy: 57.7455176637025, Accuracy STD: 4.030269065014441, Avg F1: 0.41508470352674537, F1 STD: 0.03906594671867556

a=0.9, b=0.8, c=0.2: Avg Accuracy: 58.118472372105565, Accuracy STD: 3.387386553771265, Avg F1: 0.41690411003670624, F1 STD: 0.02500455766784725

a=0.9, b=0.8, c=0.3: Avg Accuracy: 57.6214067634653, Accuracy STD: 3.3177640219729096, Avg F1: 0.4159800935006002, F1 STD: 0.029469139648801282

a=0.9, b=0.8, c=0.4: Avg Accuracy: 56.913574782455626, Accuracy STD: 3.5845915311001812, Avg F1: 0.40355843118876883, F1 STD: 0.032229229056214115

a=0.9, b=0.8, c=0.5: Avg Accuracy: 56.404040878480224, Accuracy STD: 3.4626957256117126, Avg F1: 0.40327597330352083, F1 STD: 0.024576646192304798

a=0.9, b=0.9, c=0.1: Avg Accuracy: 58.31414401512609, Accuracy STD: 3.1296883415443415, Avg F1: 0.42705150676570514, F1 STD: 0.02562756382525936

a=0.9, b=0.9, c=0.2: Avg Accuracy: 58.81928971361392, Accuracy STD: 2.9609321891578126, Avg F1: 0.41810598260226073, F1 STD: 0.02054164845730631

a=0.9, b=0.9, c=0.3: Avg Accuracy: 58.74595644807285, Accuracy STD: 2.9947823442879105, Avg F1: 0.4200832156753411, F1 STD: 0.028320370246559225

a=0.9, b=0.9, c=0.4: Avg Accuracy: 56.939451243460795, Accuracy STD: 3.2016252311359183, Avg F1: 0.4062825366169284, F1 STD: 0.02221360750900834

a=0.9, b=0.9, c=0.5: Avg Accuracy: 57.55257737224048, Accuracy STD: 4.153525677254753, Avg F1: 0.4118020627327281, F1 STD: 0.027386221062247557

a=0.9, b=1, c=0.1: Avg Accuracy: 59.27935383521035, Accuracy STD: 3.813521425746436, Avg F1: 0.4273872801931386, F1 STD: 0.034769370959234755

a=0.9, b=1, c=0.2: Avg Accuracy: 57.85146598859876, Accuracy STD: 2.1671990655866904, Avg F1: 0.4166425292747221, F1 STD: 0.022463162498398793

a=0.9, b=1, c=0.3: Avg Accuracy: 57.92922157195028, Accuracy STD: 3.1598518459653775, Avg F1: 0.41501245671965104, F1 STD: 0.023340473268629994

a=0.9, b=1, c=0.4: Avg Accuracy: 56.88287061340729, Accuracy STD: 3.249755476213485, Avg F1: 0.40882599025996724, F1 STD: 0.025191392515404126

a=0.9, b=1, c=0.5: Avg Accuracy: 56.88109889992562, Accuracy STD: 3.4419701239493037, Avg F1: 0.4040362722802248, F1 STD: 0.027195172392633183

a=0.8, b=0.5, c=0.1: Avg Accuracy: 58.09957253930995, Accuracy STD: 4.00933717255157, Avg F1: 0.41969666274654094, F1 STD: 0.03395586498893456

a=0.8, b=0.5, c=0.2: Avg Accuracy: 57.51777854860645, Accuracy STD: 3.331121069435474, Avg F1: 0.4084880789608204, F1 STD: 0.02107672788771445

a=0.8, b=0.5, c=0.3: Avg Accuracy: 57.24729205932296, Accuracy STD: 2.8890606998771404, Avg F1: 0.41004567622700555, F1 STD: 0.029234064695510254

a=0.8, b=0.5, c=0.4: Avg Accuracy: 56.452909578805375, Accuracy STD: 2.959648529373006, Avg F1: 0.4024000243764051, F1 STD: 0.027400657142757176

a=0.8, b=0.5, c=0.5: Avg Accuracy: 56.63692777731078, Accuracy STD: 2.778882457040634, Avg F1: 0.40231698318361675, F1 STD: 0.02294102506997958

a=0.8, b=0.6, c=0.1: Avg Accuracy: 58.301349596180174, Accuracy STD: 3.8120458530167194, Avg F1: 0.4136129767537512, F1 STD: 0.034703479662705795

a=0.8, b=0.6, c=0.2: Avg Accuracy: 58.24465274499281, Accuracy STD: 2.8566224712829236, Avg F1: 0.41608459473246207, F1 STD: 0.022867449549547277

a=0.8, b=0.6, c=0.3: Avg Accuracy: 56.977486602908115, Accuracy STD: 3.3134222384298804, Avg F1: 0.40498895948803354, F1 STD: 0.023089846687551173

a=0.8, b=0.6, c=0.4: Avg Accuracy: 56.737770051848884, Accuracy STD: 3.733332793749419, Avg F1: 0.4072108024653505, F1 STD: 0.03565493075485929

a=0.8, b=0.6, c=0.5: Avg Accuracy: 56.35786759033732, Accuracy STD: 2.723993841921827, Avg F1: 0.39569109549102166, F1 STD: 0.02438876579017663

a=0.8, b=0.7, c=0.1: Avg Accuracy: 57.83911588936462, Accuracy STD: 4.080726539366856, Avg F1: 0.41834333780342686, F1 STD: 0.028100441063228624

a=0.8, b=0.7, c=0.2: Avg Accuracy: 57.55222555654194, Accuracy STD: 2.703154747847519, Avg F1: 0.410953022572746, F1 STD: 0.023241871563773803

a=0.8, b=0.7, c=0.3: Avg Accuracy: 57.62757768381302, Accuracy STD: 3.527100534838787, Avg F1: 0.41788614151126824, F1 STD: 0.030251658512207253

a=0.8, b=0.7, c=0.4: Avg Accuracy: 56.894267265508006, Accuracy STD: 2.8813334120499596, Avg F1: 0.410788605668351, F1 STD: 0.029749770008659977

a=0.8, b=0.7, c=0.5: Avg Accuracy: 57.14690327394351, Accuracy STD: 2.7972775373381773, Avg F1: 0.4107374231127213, F1 STD: 0.033515004056860684

a=0.8, b=0.8, c=0.1: Avg Accuracy: 57.91439124999408, Accuracy STD: 3.23883051088257, Avg F1: 0.4143914768119942, F1 STD: 0.025965552172091077

a=0.8, b=0.8, c=0.2: Avg Accuracy: 58.084896939054524, Accuracy STD: 3.4556917509322838, Avg F1: 0.41720750298703235, F1 STD: 0.02461970884775678

a=0.8, b=0.8, c=0.3: Avg Accuracy: 57.14344005802801, Accuracy STD: 3.1843467777638637, Avg F1: 0.40628556908150604, F1 STD: 0.024237010111923155

a=0.8, b=0.8, c=0.4: Avg Accuracy: 57.003697804119305, Accuracy STD: 3.123330008155274, Avg F1: 0.40690420341680394, F1 STD: 0.023355864359835048

a=0.8, b=0.8, c=0.5: Avg Accuracy: 56.8513388712701, Accuracy STD: 3.021599746835814, Avg F1: 0.40352164999794127, F1 STD: 0.026281549255685575

a=0.8, b=0.9, c=0.1: Avg Accuracy: 58.41371278139767, Accuracy STD: 2.9198493064132625, Avg F1: 0.42499396057221206, F1 STD: 0.031062159078401336

a=0.8, b=0.9, c=0.2: Avg Accuracy: 58.386448095643665, Accuracy STD: 2.9792276905413346, Avg F1: 0.41550826171998806, F1 STD: 0.025822473230062596

a=0.8, b=0.9, c=0.3: Avg Accuracy: 58.633014244092735, Accuracy STD: 4.351633708750097, Avg F1: 0.4235178229085149, F1 STD: 0.0341422951202873

a=0.8, b=0.9, c=0.4: Avg Accuracy: 56.8863725319546, Accuracy STD: 4.174054787195522, Avg F1: 0.41600883218095763, F1 STD: 0.03003509972122211

a=0.8, b=0.9, c=0.5: Avg Accuracy: 57.03710474626295, Accuracy STD: 3.174136534921053, Avg F1: 0.40846494079358164, F1 STD: 0.029184598534204166

a=0.8, b=1, c=0.1: Avg Accuracy: 58.473105142479675, Accuracy STD: 2.810349344802369, Avg F1: 0.4274772643782649, F1 STD: 0.029368188786359634

a=0.8, b=1, c=0.2: Avg Accuracy: 58.07236470126345, Accuracy STD: 3.1783852603654585, Avg F1: 0.41657096174992125, F1 STD: 0.02880847514043164

a=0.8, b=1, c=0.3: Avg Accuracy: 58.15438092367284, Accuracy STD: 2.9302220896196127, Avg F1: 0.41696701911632417, F1 STD: 0.023146571775737985

a=0.8, b=1, c=0.4: Avg Accuracy: 57.35404822375524, Accuracy STD: 3.4182792834007456, Avg F1: 0.4052234750235032, F1 STD: 0.020023638573461084

a=0.8, b=1, c=0.5: Avg Accuracy: 56.984189209735725, Accuracy STD: 3.512535977785808, Avg F1: 0.4056909540819733, F1 STD: 0.029729487417452297

a=0.7, b=0.5, c=0.1: Avg Accuracy: 57.89694641635428, Accuracy STD: 3.403238734381127, Avg F1: 0.41667341948157155, F1 STD: 0.03361396435169525

a=0.7, b=0.5, c=0.2: Avg Accuracy: 57.81585459495047, Accuracy STD: 2.8570810076704887, Avg F1: 0.40798407550403065, F1 STD: 0.020391532228292684

a=0.7, b=0.5, c=0.3: Avg Accuracy: 56.63424992111446, Accuracy STD: 3.1305899661255854, Avg F1: 0.4055054209981484, F1 STD: 0.026146501831280877

a=0.7, b=0.5, c=0.4: Avg Accuracy: 56.60150574412084, Accuracy STD: 3.1722585701196975, Avg F1: 0.405582652764665, F1 STD: 0.029102105654662484

a=0.7, b=0.5, c=0.5: Avg Accuracy: 56.505014185341985, Accuracy STD: 3.7763303827947117, Avg F1: 0.40469128898090256, F1 STD: 0.03066841221770144

a=0.7, b=0.6, c=0.1: Avg Accuracy: 58.84898326285186, Accuracy STD: 2.9718020442268114, Avg F1: 0.42898880878512724, F1 STD: 0.02500309725514113

a=0.7, b=0.6, c=0.2: Avg Accuracy: 57.601771099636906, Accuracy STD: 2.7549705557182285, Avg F1: 0.4108542367874283, F1 STD: 0.02352271460922359

a=0.7, b=0.6, c=0.3: Avg Accuracy: 56.925107755171396, Accuracy STD: 2.998083601163314, Avg F1: 0.4074811673393651, F1 STD: 0.025984319346475587

a=0.7, b=0.6, c=0.4: Avg Accuracy: 56.932137588678245, Accuracy STD: 3.5819745251427286, Avg F1: 0.40114203264633996, F1 STD: 0.021060931255464294

a=0.7, b=0.8, c=0.1: Avg Accuracy: 58.002348697178114, Accuracy STD: 3.024455436807054, Avg F1: 0.41858559638565895, F1 STD: 0.023585827621155683

a=0.7, b=0.8, c=0.2: Avg Accuracy: 58.1617176558589, Accuracy STD: 4.294381657058781, Avg F1: 0.417194466153456, F1 STD: 0.030217697387938513

a=0.7, b=0.8, c=0.3: Avg Accuracy: 57.0219933992927, Accuracy STD: 3.521806469065303, Avg F1: 0.4083476436500413, F1 STD: 0.025128321303625918

a=0.7, b=0.8, c=0.4: Avg Accuracy: 57.35979406336931, Accuracy STD: 3.7193638679332786, Avg F1: 0.4103753111677402, F1 STD: 0.028565081165835927

a=0.7, b=0.8, c=0.5: Avg Accuracy: 56.976633995111555, Accuracy STD: 2.943476970129667, Avg F1: 0.40340631745397204, F1 STD: 0.024316092001376802

a=0.7, b=0.9, c=0.1: Avg Accuracy: 58.2497106289844, Accuracy STD: 2.7925666633587545, Avg F1: 0.41839865114514546, F1 STD: 0.03480151938655596

a=0.7, b=0.9, c=0.2: Avg Accuracy: 58.319351488338086, Accuracy STD: 3.6289247736957035, Avg F1: 0.41984133142508445, F1 STD: 0.03164692064915879

a=0.7, b=0.9, c=0.3: Avg Accuracy: 57.41050001975484, Accuracy STD: 2.9338080782988403, Avg F1: 0.40810145337262316, F1 STD: 0.024573990910454552

a=0.7, b=0.9, c=0.4: Avg Accuracy: 57.11828340279085, Accuracy STD: 3.07474946767724, Avg F1: 0.4124255088344082, F1 STD: 0.021558897626425697

a=0.7, b=0.9, c=0.5: Avg Accuracy: 56.81582109872001, Accuracy STD: 3.2857412759992535, Avg F1: 0.40285411339814436, F1 STD: 0.03148552216465776

a=0.7, b=1, c=0.1: Avg Accuracy: 58.406007417741876, Accuracy STD: 3.353456733739938, Avg F1: 0.4191761233803426, F1 STD: 0.022982727581235547

a=0.7, b=1, c=0.2: Avg Accuracy: 58.242560068029334, Accuracy STD: 3.316830595761772, Avg F1: 0.4123322389184561, F1 STD: 0.034983812579582815

a=0.7, b=1, c=0.3: Avg Accuracy: 57.20601161308332, Accuracy STD: 3.5521612843925396, Avg F1: 0.40584774040222965, F1 STD: 0.027391884085112984

a=0.7, b=1, c=0.4: Avg Accuracy: 56.95533219316623, Accuracy STD: 3.443618082066778, Avg F1: 0.4097927354245174, F1 STD: 0.02748416310899763

a=0.7, b=1, c=0.5: Avg Accuracy: 56.82295351530321, Accuracy STD: 3.140919103700741, Avg F1: 0.40448887506986475, F1 STD: 0.03055857952581195

## Hyperparamters Grid search

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=64 =>

Avg Accuracy: 62.675818145100024, Accuracy STD: 3.4597881333995306, Avg F1: 0.4120343894210808, F1 STD: 0.03359170631078166

### Code and Results

In [None]:
import os
import pandas as pd
import subprocess
from itertools import product

# Define the parameter grid
param_grid = {
    'weight_decay': [1e-4, 1e-3],
    'epochs': [50, 100],
    'corr_epochs': [50, 100],
    'batch_size': [32, 64],
    'scheduler_steps': [32, 64],
    'coscheduler_steps': [32, 64]
}

# Create combinations of parameters
param_combinations = list(product(param_grid['weight_decay'], param_grid['epochs'], param_grid['corr_epochs'], param_grid['batch_size'], param_grid['scheduler_steps'], param_grid['coscheduler_steps']))

for weight_decay, epochs, corr_epochs, batch_size, scheduler_steps, coscheduler_steps in param_combinations:
    # Read the hparams file
    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/configs/hparams.py', 'r') as file:
        data = file.readlines()

    # Modify the file with the current parameters
    for i, line in enumerate(data):
        if 'num_epochs' in line:
            data[i] = f'                \'num_epochs\': {epochs},\n'
        if 'batch_size' in line:
            data[i] = f'                \'batch_size\': {batch_size},\n'
        if 'weight_decay' in line:
            data[i] = f'                \'weight_decay\': {weight_decay},\n'
        if 'corr_epochs' in line:
            data[i] = f'                \'corr_epochs\': {corr_epochs},\n'
        if 'scheduler_steps' in line:
            data[i] = f'                \'scheduler_steps\': {scheduler_steps},\n'
        if 'coscheduler_steps' in line:
            data[i] = f'                \'coscheduler_steps\': {coscheduler_steps},\n'

    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/configs/hparams.py', 'w') as file:
        file.writelines(data)

    # Run the script
    result = subprocess.run(['python3', '/kaggle/working/SRP_Domain_Adaptation/Raincoat/main.py', '--experiment_description', 'WISDM', '--dataset', 'WISDM', '--num_runs', '5'], capture_output=True, text=True)

    # Check if the script ran successfully
    if result.returncode != 0:
        print(f"Error running script for weight_decay={weight_decay}, epochs={epochs}, corr_epochs={corr_epochs}, batch_size={batch_size}, scheduler_steps={scheduler_steps}, coscheduler_steps={coscheduler_steps}: {result.stderr}")
        continue

    # Read the results
    file_path = '/kaggle/working/SRP_Domain_Adaptation/experiments_logs/RAINCOAT ClosedSet/WISDM/average_correct.csv'
    df = pd.read_csv(file_path)
    last_row = df.tail(1).values[0]

    avg_accuracy = last_row[0]
    accuracy_std = last_row[1]
    avg_f1 = last_row[2]
    f1_std = last_row[3]

    print(f"weight_decay={weight_decay}, epochs={epochs}, corr_epochs={corr_epochs}, batch_size={batch_size}, scheduler_steps={scheduler_steps}, coscheduler_steps={coscheduler_steps},{avg_accuracy}, {accuracy_std}, {avg_f1}, {f1_std}")

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 60.845348374127035, Accuracy STD: 1.7344759638339933, Avg F1: 0.3722182080349451, F1 STD: 0.014925648455502918

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 60.90828446607776, Accuracy STD: 1.782216296059289, Avg F1: 0.3739647538465996, F1 STD: 0.01321645067382626

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 61.484525674517975, Accuracy STD: 2.4384478405560603, Avg F1: 0.37549753871975, F1 STD: 0.022606919863084322

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 61.55009070650092, Accuracy STD: 2.371484709069984, Avg F1: 0.37503651692252793, F1 STD: 0.02207972851173559

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 62.4578234240365, Accuracy STD: 1.0565785315875675, Avg F1: 0.32716888160399427, F1 STD: 0.020410215426255915

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 62.52091333958542, Accuracy STD: 1.0402562921934642, Avg F1: 0.32767428325305453, F1 STD: 0.01993975698316996

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 61.49227597076081, Accuracy STD: 1.9678950451139647, Avg F1: 0.32530500731991163, F1 STD: 0.016483293788097687

weight_decay=0.0001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 61.49475982570368, Accuracy STD: 1.92830346292595, Avg F1: 0.32497185032345544, F1 STD: 0.016014365959256094

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 58.158873114901986, Accuracy STD: 1.8568497931957322, Avg F1: 0.3620688146354969, F1 STD: 0.01689703548620068

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 58.12815559442521, Accuracy STD: 1.8557732310287787, Avg F1: 0.36257217372425354, F1 STD: 0.016137901265122958

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 58.415951644893674, Accuracy STD: 2.0281835764732126, Avg F1: 0.36600265545268856, F1 STD: 0.018554038200560986

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 58.68886170643397, Accuracy STD: 1.767760813408885, Avg F1: 0.36871551659104856, F1 STD: 0.016857423270283087

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 61.72094827935806, Accuracy STD: 0.6948181206048507, Avg F1: 0.34512401773073487, F1 STD: 0.020660005338472422

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 61.77103606296991, Accuracy STD: 0.4583626005655768, Avg F1: 0.34534572117047324, F1 STD: 0.01866011041620703

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 59.310565726364175, Accuracy STD: 2.749848640869131, Avg F1: 0.3187282355134634, F1 STD: 0.023327278695258358

weight_decay=0.0001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 59.690002297905494, Accuracy STD: 2.667838620402964, Avg F1: 0.3229778501606616, F1 STD: 0.026390324764531844

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 62.644278933238105, Accuracy STD: 3.445644986538159, Avg F1: 0.4114964752077094, F1 STD: 0.03376479024223431

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 62.675818145100024, Accuracy STD: 3.4597881333995306, Avg F1: 0.4120343894210808, F1 STD: 0.03359170631078166

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 61.63162509503513, Accuracy STD: 2.9014120948260387, Avg F1: 0.4083369819925228, F1 STD: 0.03653952560398736

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 61.712107588325594, Accuracy STD: 2.868859521192415, Avg F1: 0.40880076276156185, F1 STD: 0.036637516592270866

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 63.990736382727064, Accuracy STD: 1.8721446247041265, Avg F1: 0.3990160947832722, F1 STD: 0.010000799153395239

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 64.10875966996494, Accuracy STD: 1.9062108315822681, Avg F1: 0.4002483154213853, F1 STD: 0.011949792969415028

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 65.02366835332393, Accuracy STD: 1.6742590673094733, Avg F1: 0.40779422432381673, F1 STD: 0.026729941870148995

weight_decay=0.0001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 65.01434434399991, Accuracy STD: 1.848297029750803, Avg F1: 0.4080564784187527, F1 STD: 0.02761564739327226

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 59.56406455849468, Accuracy STD: 2.7124367871305375, Avg F1: 0.39792507338417993, F1 STD: 0.02946070227285435

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 59.18104180635904, Accuracy STD: 2.772230888041082, Avg F1: 0.3962387722862205, F1 STD: 0.02744356904320329

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 57.87687538336424, Accuracy STD: 1.7515607368264274, Avg F1: 0.38754973113323277, F1 STD: 0.014768180844520513

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 58.165581560691635, Accuracy STD: 2.1782691569431485, Avg F1: 0.3923920294290235, F1 STD: 0.02397463146129952

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 62.41918250729155, Accuracy STD: 3.1837132427642616, Avg F1: 0.3874666680857143, F1 STD: 0.026532683800023816

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 62.74284872598565, Accuracy STD: 3.1631133787376977, Avg F1: 0.3920695094753827, F1 STD: 0.027617653453747223

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 62.850997785363084, Accuracy STD: 3.0987561438817766, Avg F1: 0.38738205503445416, F1 STD: 0.025008460468795807

weight_decay=0.0001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 63.12743635608325, Accuracy STD: 3.108823060437285, Avg F1: 0.39088244125833527, F1 STD: 0.029820498373596186

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 61.927643086107636, Accuracy STD: 2.1403723325393673, Avg F1: 0.3860853212193859, F1 STD: 0.01783798839344192

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 62.06005793494039, Accuracy STD: 2.202209395552627, Avg F1: 0.3877923447928038, F1 STD: 0.02072743013572295

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 60.29412828920391, Accuracy STD: 2.393362340330246, Avg F1: 0.3682487052301188, F1 STD: 0.02476484663236907

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 60.08404984006212, Accuracy STD: 2.287130666929884, Avg F1: 0.3643903767620673, F1 STD: 0.026637469289113206

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 63.158090545575284, Accuracy STD: 2.2355859683897443, Avg F1: 0.3445950626695686, F1 STD: 0.021409042659625958

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 63.13693029901035, Accuracy STD: 2.203123940702611, Avg F1: 0.34424293753040874, F1 STD: 0.021151195541013853

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 62.590451916332334, Accuracy STD: 1.4724353769009582, Avg F1: 0.334301150574036, F1 STD: 0.019903044550956962

weight_decay=0.001, epochs=50, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 62.63880656714601, Accuracy STD: 1.4227538739688088, Avg F1: 0.33472094290857013, F1 STD: 0.01980995174425126

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 59.158827511277266, Accuracy STD: 2.4805820109961085, Avg F1: 0.37698749688176314, F1 STD: 0.014537584767592801

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 59.350572374358066, Accuracy STD: 2.2093895443516187, Avg F1: 0.37940105095842785, F1 STD: 0.01342122137861234

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 57.52259084182342, Accuracy STD: 2.5825125851160506, Avg F1: 0.35569724666965724, F1 STD: 0.018105329717540194

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 57.61233178489075, Accuracy STD: 2.8428522602414823, Avg F1: 0.3559644133938414, F1 STD: 0.021531856174499927

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 61.078970280470955, Accuracy STD: 2.1174398514332373, Avg F1: 0.34113002286308114, F1 STD: 0.015990584858768732

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 61.42936467012582, Accuracy STD: 2.3363813634391644, Avg F1: 0.34263786272031904, F1 STD: 0.015909409088659458

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 61.73841217849041, Accuracy STD: 1.268044151911372, Avg F1: 0.34238830443618, F1 STD: 0.01593718457984223

weight_decay=0.001, epochs=50, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 61.87112221170814, Accuracy STD: 1.6112977533611061, Avg F1: 0.34185537167881586, F1 STD: 0.01723194079700186

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 61.31583169247496, Accuracy STD: 2.8552565449560166, Avg F1: 0.40184645141314856, F1 STD: 0.03213451215830871

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 61.554015772076944, Accuracy STD: 2.9149876560214003, Avg F1: 0.4036582821408209, F1 STD: 0.03324300733296497

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 61.42353783447241, Accuracy STD: 1.9779130350647631, Avg F1: 0.40088665172076077, F1 STD: 0.02641754209607126

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 61.36143327236785, Accuracy STD: 2.0341643793343196, Avg F1: 0.3997248703890241, F1 STD: 0.025674819924756717

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 63.90275254044642, Accuracy STD: 3.203674846196395, Avg F1: 0.38967784691552104, F1 STD: 0.02763127100004394

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 64.01584777854166, Accuracy STD: 3.0436812730675507, Avg F1: 0.3909096236797195, F1 STD: 0.025250190664625586

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 64.62285140636818, Accuracy STD: 1.6255677049300092, Avg F1: 0.40196636838861083, F1 STD: 0.010055778686583774

weight_decay=0.001, epochs=100, corr_epochs=50, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 64.70383106269841, Accuracy STD: 1.6580218898379568, Avg F1: 0.402331351085954, F1 STD: 0.009817898123462086

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 58.48539836805246, Accuracy STD: 2.047523990637763, Avg F1: 0.38914239515712445, F1 STD: 0.02018133691743511

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 58.30584160640613, Accuracy STD: 1.8552318250554916, Avg F1: 0.38748068025588717, F1 STD: 0.022528633447513113

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 57.41842824953871, Accuracy STD: 2.288472766277704, Avg F1: 0.37965639810590346, F1 STD: 0.018306719156939898

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=32, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 57.34478688161572, Accuracy STD: 2.2507432906911506, Avg F1: 0.3800309924898363, F1 STD: 0.01815741228804497

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=32,Avg Accuracy: 62.3251452080323, Accuracy STD: 4.078796956845765, Avg F1: 0.3836007016142146, F1 STD: 0.02829732937872749

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=32, coscheduler_steps=64,Avg Accuracy: 62.38561844426081, Accuracy STD: 4.0125240838495015, Avg F1: 0.38438157378534216, F1 STD: 0.028871898898666752

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=32,Avg Accuracy: 62.47270567137652, Accuracy STD: 3.3822830453294443, Avg F1: 0.38050247036097035, F1 STD: 0.014787200689262358

weight_decay=0.001, epochs=100, corr_epochs=100, batch_size=64, scheduler_steps=64, coscheduler_steps=64,Avg Accuracy: 62.67624963984586, Accuracy STD: 3.2504900248872364, Avg F1: 0.38199190105444575, F1 STD: 0.015095463580518784

## a value Grid Search 

because FrFT(FrFT(x, a), -a) = x, and best a value is 0.4

### Code and results

In [None]:
import os
import pandas as pd
import subprocess
from itertools import product

a_values = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]

for a in a_values:
    # Read the hparams file
    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/configs/data_model_configs.py', 'r') as file:
        data = file.readlines()

    # Modify the file with the current parameters
    for i, line in enumerate(data):
        if 'self.fraction_order' in line:
            data[i] = f'        self.fraction_order = {a}\n'

    with open('/kaggle/working/SRP_Domain_Adaptation/Raincoat/configs/data_model_configs.py', 'w') as file:
        file.writelines(data)

    # Run the script
    result = subprocess.run(['python3', '/kaggle/working/SRP_Domain_Adaptation/Raincoat/main.py', '--experiment_description', 'WISDM', '--dataset', 'WISDM', '--num_runs', '5', '--device', 'cuda'], capture_output=True, text=True)

    # Check if the script ran successfully
    if result.returncode != 0:
        print(f"Error running script for a = {a}: {result.stderr}")
        continue

    # Read the results
    file_path = '/kaggle/working/SRP_Domain_Adaptation/experiments_logs/RAINCOAT ClosedSet/WISDM/average_correct.csv'
    df = pd.read_csv(file_path)
    last_row = df.tail(1).values[0]

    avg_accuracy = last_row[0]
    accuracy_std = last_row[1]
    avg_f1 = last_row[2]
    f1_std = last_row[3]

    print(f"a = {a}: Avg Accuracy: {avg_accuracy}, Accuracy STD: {accuracy_std}, Avg F1: {avg_f1}, F1 STD: {f1_std}")

a = 0: Avg Accuracy: Avg Accuracy: 61.83598440013138, Accuracy STD: Accuracy STD: 2.6520156577608174, Avg F1: Avg F1: 0.3992052846994296, F1 STD: F1 STD: 0.017602405809568684

a = 0.1: Avg Accuracy: Avg Accuracy: 60.32036157784269, Accuracy STD: Accuracy STD: 2.3114505032113164, Avg F1: Avg F1: 0.3884078172342168, F1 STD: F1 STD: 0.01577545739019111

a = 0.2: Avg Accuracy: Avg Accuracy: 61.680044593674836, Accuracy STD: Accuracy STD: 1.7654643101976841, Avg F1: Avg F1: 0.38770826932306407, F1 STD: F1 STD: 0.011529695452993352

a = 0.3: Avg Accuracy: Avg Accuracy: 61.20645447104728, Accuracy STD: Accuracy STD: 1.697975524541123, Avg F1: Avg F1: 0.38847938362908707, F1 STD: F1 STD: 0.022309338814627187

a = 0.4: Avg Accuracy: Avg Accuracy: 62.644278933238105, Accuracy STD: Accuracy STD: 3.445644986538159, Avg F1: Avg F1: 0.4114964752077094, F1 STD: F1 STD: 0.03376479024223431

a = 0.5: Avg Accuracy: Avg Accuracy: 62.01382502989119, Accuracy STD: Accuracy STD: 2.434137036311512, Avg F1: Avg F1: 0.3938725945468643, F1 STD: F1 STD: 0.0269822282249586

a = 0.6: Avg Accuracy: Avg Accuracy: 62.26678395624181, Accuracy STD: Accuracy STD: 2.668188752684112, Avg F1: Avg F1: 0.395057418136893, F1 STD: F1 STD: 0.02169692153212977

a = 0.7: Avg Accuracy: Avg Accuracy: 60.6739319127117, Accuracy STD: Accuracy STD: 2.773740379945678, Avg F1: Avg F1: 0.3908509479637413, F1 STD: F1 STD: 0.035546245784000466

a = 0.8: Avg Accuracy: Avg Accuracy: 61.434794413386406, Accuracy STD: Accuracy STD: 3.035998227735259, Avg F1: Avg F1: 0.3882534595159481, F1 STD: F1 STD: 0.02036220599321866

a = 0.9: Avg Accuracy: Avg Accuracy: 61.09161064408859, Accuracy STD: Accuracy STD: 3.790634018309286, Avg F1: Avg F1: 0.39966012671382545, F1 STD: F1 STD: 0.026671688444261918

a = 1: Avg Accuracy: Avg Accuracy: 62.07641722507568, Accuracy STD: Accuracy STD: 1.7902636447653781, Avg F1: Avg F1: 0.4082623282213465, F1 STD: F1 STD: 0.012961571512324946

# Run and check results

In [23]:
# remove old logs
!rm -rf experiments_logs

In [None]:
!python3 /content/SRP_Domain_Adaptation/main.py --experiment_description WISDM --dataset WISDM --num_runs 1

In [4]:
import pandas as pd

file_path = 'experiments_logs/RAINCOAT ClosedSet/WISDM/average_correct.csv'

# Read the CSV file
df = pd.read_csv(file_path)

# Extract the last row
last_row = df.tail(1).values[0]

# Extract the first four columns
avg_accuracy = last_row[0]
accuracy_std = last_row[1]
avg_f1 = last_row[2]
f1_std = last_row[3]

# Print the values side by side
print(f"{avg_accuracy}, {accuracy_std}, {avg_f1}, {f1_std}")


Avg Accuracy: 63.11333738107413, Accuracy STD: 2.699980524318425, Avg F1: 0.47183228003868677, F1 STD: 0.01960662122499978


# UNIDA Run

In [None]:
!pip install diptest

Collecting diptest
  Downloading diptest-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.0 kB)
Downloading diptest-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (195 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/195.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m195.7/195.7 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: diptest
Successfully installed diptest-0.8.0


In [None]:
# remove old logs
!rm -rf experiments_logs

In [None]:
!python main_uni.py --experiment_description WISDM --dataset WISDM --num_runs 5

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
best f1: 0.47205387205387206
[Epoch : 25/50]
best f1: 0.48033225684443276
[Epoch : 27/50]
best f1: 0.5082934609250399
[Epoch : 40/50]
best f1: 0.5532495256166983
===== Correct ====
[10000000000.0, 10000000000.0, 10000000000.0, 10000000000.0, 10000000000.0, array([0.11599505])]
Dataset: WISDM
Method:  RAINCOAT
Source: 16 ---> Target: 4
Run ID: 1
[Epoch : 1/50]
best f1: 0.20472120472120472
[Epoch : 2/50]
best f1: 0.2534154844499672
[Epoch : 3/50]
best f1: 0.30009620009620014
[Epoch : 4/50]
best f1: 0.4028138528138528
[Epoch : 5/50]
best f1: 0.4411645364171231
[Epoch : 6/50]
best f1: 0.47614496216646757
[Epoch : 7/50]
best f1: 0.48282828282828283
[Epoch : 10/50]
best f1: 0.49968645677714446
[Epoch : 11/50]
best f1: 0.520230607966457
[Epoch : 12/50]
best f1: 0.5271392081736909
[Epoch : 13/50]
best f1: 0.5453805453805454
[Epoch : 18/50]
best f1: 0.5563492063492064
[Epoch : 19/50]
best f1: 0.5640223747775235
===== Correct ====


In [None]:
import pandas as pd

file_path = '/content/SRP_Domain_Adaptation/experiments_logs/WISDM-RAINCOAT-uni/WISDM-RAINCOAT-uni/average_correct.csv'

df = pd.read_csv(file_path)

#print(df)

# print last row (avg)
print(df.tail(1))
