In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import warnings
import copy
import os
warnings.filterwarnings('ignore')

In [2]:
for subid1 in range(10):
    for subid2 in range(10):
        if subid1 != subid2:
            path = f'weights_steps/Sub{subid1+1}ToSub{subid2+1}_AllChannels'
            if not os.path.exists(path):
                os.makedirs(path)

In [3]:
def show(condition, fake, real):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    condition = condition.detach().cpu().numpy()
    fake = fake.detach().cpu().numpy()
    real = real.detach().cpu().numpy()
    t = np.arange(0, 200, 1)
    fig, axs = plt.subplots(2, 1)
    axs[0].plot(t, condition[0, 0])
    
    axs[1].plot(t, real[0, 0], label='real')
    axs[1].plot(t, fake[0, 0], label='fake')
    axs[1].legend()
    plt.show()

2D U-Net

In [4]:
def crop(data, new_shape):
    '''
    Function for cropping an image tensor: Given an image tensor and the new shape,
    crops to the center pixels.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width)
        new_shape: a torch.Size object with the shape you want x to have
    '''
    middle_length = data.shape[2] // 2
    starting = middle_length - round(new_shape[2] / 2)
    final = starting + new_shape[2]
    cropped_data = data[:, :, starting:final]
    return cropped_data

class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Performs two convolutions followed by a max pool operation.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, input_channels * 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(input_channels * 2, input_channels * 2, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        if use_bn:
            self.batchnorm = nn.BatchNorm1d(input_channels * 2)
        self.use_bn = use_bn
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x):
        '''
        Function for completing a forward pass of ContractingBlock: 
        Given an image tensor, completes a contracting block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class:
    Performs an upsampling, a convolution, a concatenation of its two inputs,
    followed by two more convolutions with optional dropout
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ExpandingBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv1d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(input_channels // 2, input_channels // 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(input_channels // 2, input_channels // 2, kernel_size=3, padding=1)
        if use_bn:
            self.batchnorm = nn.BatchNorm1d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x):
        '''
        Function for completing a forward pass of ExpandingBlock: 
        Given an image tensor, completes an expanding block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: the image tensor from the contracting path (from the opposing block of x)
                    for the skip connection
        '''
        x = self.upsample(x)
        x = self.conv1(x)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv3(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x


class FeatureMapBlock0(nn.Module):
    '''
    FeatureMapBlock Class
    The first layers of a U-Net - 
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock0, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(output_channels, output_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.batchnorm = nn.BatchNorm1d(output_channels)

    def forward(self, x):
        '''
        Function for completing a forward pass of FeatureMapBlock: 
        Given an image tensor, returns it mapped to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        return x   
    
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    The final layer of a U-Net - 
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv1d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Function for completing a forward pass of FeatureMapBlock: 
        Given an image tensor, returns it mapped to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv(x)
        return x

class UNet(nn.Module):
    '''
    UNet Class
    A series of 4 contracting blocks followed by 4 expanding blocks to 
    transform an input image into the corresponding paired image, with an upfeature
    layer at the start and a downfeature layer at the end.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super(UNet, self).__init__()
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.contract1 = FeatureMapBlock0(input_channels, hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels, use_dropout=True)
        self.contract3 = ContractingBlock(hidden_channels * 2, use_dropout=True)
        self.contract4 = ContractingBlock(hidden_channels * 4)
        self.expand1 = ExpandingBlock(hidden_channels * 8)
        self.expand2 = ExpandingBlock(hidden_channels * 4)
        self.expand3 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        Function for completing a forward pass of UNet: 
        Given an image tensor, passes it through U-Net and returns the output.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x1 = self.contract1(x)
        x2 = self.maxpool(x1)
        x3 = self.contract2(x2)
        x4 = self.maxpool(x3)
        x5 = self.contract3(x4)
        x6 = self.maxpool(x5)
        x7 = self.contract4(x6)
        x8 = self.expand1(x7)
        x9 = self.expand2(x8)
        x10 = self.expand3(x9)
        xn = self.downfeature(x10)
        #return self.sigmoid(xn)
        return xn

In [5]:
import torch.nn.functional as F
# New parameters
recon_criterion1 = nn.CosineEmbeddingLoss(0.5)
recon_criterion2 = nn.MSELoss()
lambda_recon1 = 1
lambda_recon2 = 1

n_epochs = 50
input_dim = 17
real_dim = 17
display_step = 1000
batch_size = 32
lr = 0.002
#target_shape = 100
device = 'cuda'

import data

In [6]:
class GetData(torch.utils.data.Dataset):
    
    def __init__(self, data1, data2):
        self.data1 = data1
        self.data2 = data2
    
    def __getitem__(self, index):
        data1 = self.data1[index]
        data2 = self.data2[index]
        return data1, data2
    
    def __len__(self):
        return len(self.data1)

In [7]:
gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

def weights_init(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm1d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

gen = gen.apply(weights_init)

In [8]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: get_gen_loss
def get_gen_loss(gen, real, condition, recon_criterion1, recon_criterion2, lambda_recon1, lambda_recon2):
    '''
    Return the loss of the generator given inputs.
    Parameters:
        gen: the generator; takes the condition and returns potential images
        disc: the discriminator; takes images and the condition and
          returns real/fake prediction matrices
        real: the real images (e.g. maps) to be used to evaluate the reconstruction
        condition: the source images (e.g. satellite imagery) which are used to produce the real images
        adv_criterion: the adversarial loss function; takes the discriminator 
                  predictions and the true labels and returns a adversarial 
                  loss (which you aim to minimize)
        recon_criterion: the reconstruction loss function; takes the generator 
                    outputs and the real images and returns a reconstructuion 
                    loss (which you aim to minimize)
        lambda_recon: the degree to which the reconstruction loss should be weighted in the sum
    '''
    # Steps: 1) Generate the fake images, based on the conditions.
    #        2) Evaluate the fake images and the condition with the discriminator.
    #        3) Calculate the adversarial and reconstruction losses.
    #        4) Add the two losses, weighting the reconstruction loss appropriately.
    fake = gen(condition)
    gen_rec_loss1 = recon_criterion1(real.flatten(start_dim=1).to(device), fake.flatten(start_dim=1).to(device), torch.ones(fake.shape[0]).to(device))
    gen_rec_loss2 = recon_criterion2(real, fake)
    gen_loss = lambda_recon1 * gen_rec_loss1 + lambda_recon2 * gen_rec_loss2
    return gen_loss

In [11]:
def train_and_test(data1, data2, input_dim, real_dim, lr, testsub1data, testsub2data, subid1, subid2):
    
    for k in range(10):
        
        gen = UNet(input_dim, real_dim).to(device)
        gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
        gen = gen.apply(weights_init)
        
        index = np.arange(16540)
        np.random.shuffle(index)
        
        for i in range(5):
    
            epochs_loss = np.zeros([n_epochs])
            epochs_corr = np.zeros([n_epochs])
            
            torch_data = GetData(torch.from_numpy(traindata1[index[:3000*(i+1)]]).float(), torch.from_numpy(traindata2[index[:3000*(i+1)]]).float())
    
            dataloader = DataLoader(torch_data, batch_size=batch_size, shuffle=True)
    
            best_corr = 0
    
            gen.train()

            for epoch in range(n_epochs):
                loss = 0
                # Dataloader returns the batches
                cur_step = 0
                for data1, data2 in tqdm(dataloader):
                    condition = data1
                    real = data2
                    cur_batch_size = len(condition)
                    condition = condition.to(device)
                    real = real.to(device)

                    gen_opt.zero_grad()
                    gen_loss = get_gen_loss(gen, real, condition, recon_criterion1, recon_criterion2, lambda_recon1, lambda_recon2)
                    gen_loss.backward() # Update gradients
                    gen_opt.step() # Update optimizer
                    loss += gen_loss.item()
            
                    cur_step += 1
        
                gen.eval()
                gen_opt.zero_grad()
        
                testsub1data = testsub1data.to(device)
                testsub2data = testsub2data.to(device)
        
                testsub2fakedata = gen(testsub1data)
        
                arr1 = testsub2fakedata.detach().cpu().numpy()
                arr2 = testsub2data.detach().cpu().numpy()
        
                corr = spearmanr(arr1.flatten(), arr2.flatten())[0]
            
                # Keep track of the average generator loss
                mean_loss = loss / cur_step
        
                print('Round ' + str(k+1) + ' using ' + str(3000*(i+1)) + ' trials, Sub' + str(subid1+1).zfill(2) + ' -> ' + 'Sub' + str(subid2+1).zfill(2) + ': ' + f"Epoch {epoch+1}: EEG2EEG loss: {mean_loss}, Corr : {corr}")
                #show(condition, fake, real)
                loss = 0
                epochs_loss[epoch] = mean_loss
                epochs_corr[epoch] = corr
                if corr > best_corr:
                    best_corr = corr
                    best_gen = copy.deepcopy(gen.state_dict())
                    best_gen_opt = copy.deepcopy(gen_opt.state_dict())
            torch.save({'gen':  best_gen,
                        'gen_opt': best_gen_opt
                        }, f"weights_steps/Sub{subid1+1}ToSub{subid2+1}_AllChannels/best_model_round{k+1}_{3000*(i+1)}trials.pth")
            np.save(f'weights_steps/Sub{subid1+1}ToSub{subid2+1}_AllChannels/gloss_round{k+1}_{3000*(i+1)}trials.npy', epochs_loss)
            np.save(f'weights_steps/Sub{subid1+1}ToSub{subid2+1}_AllChannels/corr_round{k+1}_{3000*(i+1)}trials.npy', epochs_corr)

for subid1 in range(10):
    for subid2 in range(10):
        if subid1 < subid2 and subid1 != 8 and subid2 != 8:
            traindata1 = np.load('eeg_data/train/sub' + str(subid1+1).zfill(2) + '.npy')
            mean1 = np.average(traindata1)
            std1 = np.std(traindata1)
            traindata1 = (traindata1-mean1)/std1
            traindata1 = np.transpose(traindata1, (1, 0, 2))
            traindata2 = np.load('eeg_data/train/sub' + str(subid2+1).zfill(2) + '.npy')
            mean2 = np.average(traindata2)
            std2 = np.std(traindata2)
            traindata2 = (traindata2-mean2)/std2
            traindata2 = np.transpose(traindata2, (1, 0, 2))
            testsub1data = np.load('eeg_data/test/sub' + str(subid1+1).zfill(2) + '.npy')
            testsub1data = (testsub1data-mean1)/std1
            testsub1data = np.transpose(testsub1data, (1, 0, 2))
            testsub1data = torch.from_numpy(testsub1data).float()
            testsub2data = np.load('eeg_data/test/sub' + str(subid2+1).zfill(2) + '.npy')
            testsub2data = (testsub2data-mean2)/std2
            testsub2data = np.transpose(testsub2data, (1, 0, 2))
            testsub2data = torch.from_numpy(testsub2data).float()
            train_and_test(traindata1, traindata2, input_dim, real_dim, lr, testsub1data, testsub2data, subid1, subid2)

  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 1: EEG2EEG loss: 1.946548801787356, Corr : 0.14550835927804814


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 2: EEG2EEG loss: 1.9068511813244922, Corr : 0.4433443932434831


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 3: EEG2EEG loss: 1.9016354616652145, Corr : 0.45181550673364956


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 4: EEG2EEG loss: 1.8766460989383942, Corr : 0.48322211903515605


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 5: EEG2EEG loss: 1.8717160947779392, Corr : 0.4799893994924209


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 6: EEG2EEG loss: 1.8782009190701423, Corr : 0.3406575960428128


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 7: EEG2EEG loss: 1.874485709565751, Corr : 0.456165411330217


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 8: EEG2EEG loss: 1.8611809801548085, Corr : 0.46305153644285163


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 9: EEG2EEG loss: 1.8675275749348579, Corr : 0.4455747421071027


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 10: EEG2EEG loss: 1.8504713304499363, Corr : 0.48811771413430194


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 11: EEG2EEG loss: 1.84425036831105, Corr : 0.47095442589068903


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 12: EEG2EEG loss: 1.8410664530510599, Corr : 0.4937698167392714


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 13: EEG2EEG loss: 1.8350536214544417, Corr : 0.4970896810528234


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 14: EEG2EEG loss: 1.8354212971443826, Corr : 0.518492576410501


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 15: EEG2EEG loss: 1.834880217592767, Corr : 0.4842965305214711


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 16: EEG2EEG loss: 1.8306815966646721, Corr : 0.5189546276482149


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 17: EEG2EEG loss: 1.8262744926391763, Corr : 0.541923510138602


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 18: EEG2EEG loss: 1.82152813546201, Corr : 0.5225691875078833


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 19: EEG2EEG loss: 1.8279415495852207, Corr : 0.5463757593967269


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 20: EEG2EEG loss: 1.813977558562096, Corr : 0.5313661262423963


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 21: EEG2EEG loss: 1.823643425677685, Corr : 0.4885538456475329


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 22: EEG2EEG loss: 1.8236904068196074, Corr : 0.5461001042573294


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 23: EEG2EEG loss: 1.8166685687734725, Corr : 0.4582074990759273


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 24: EEG2EEG loss: 1.8158998755698508, Corr : 0.5362506551389864


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 25: EEG2EEG loss: 1.8136524139566625, Corr : 0.5276008922212965


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 26: EEG2EEG loss: 1.807682865477623, Corr : 0.5392213798742607


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 27: EEG2EEG loss: 1.8042487654280155, Corr : 0.5192625247065225


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 28: EEG2EEG loss: 1.811260756025923, Corr : 0.5510622551620397


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 29: EEG2EEG loss: 1.8049840838351148, Corr : 0.5421882253305738


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 30: EEG2EEG loss: 1.807293491160616, Corr : 0.5472291368836963


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 31: EEG2EEG loss: 1.805653508673323, Corr : 0.5452969372470003


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 32: EEG2EEG loss: 1.8045536685497203, Corr : 0.5282440521306184


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 33: EEG2EEG loss: 1.8070625345757667, Corr : 0.5095949349381871


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 34: EEG2EEG loss: 1.8111810760295137, Corr : 0.5432127470495238


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 35: EEG2EEG loss: 1.803577353345587, Corr : 0.5419170713862904


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 36: EEG2EEG loss: 1.7980026234971715, Corr : 0.5519172142207677


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 37: EEG2EEG loss: 1.8000884094136826, Corr : 0.5340238225663912


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 38: EEG2EEG loss: 1.799240823755873, Corr : 0.5545808629298455


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 39: EEG2EEG loss: 1.8024154153275997, Corr : 0.550252644837494


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 40: EEG2EEG loss: 1.8075141057055046, Corr : 0.5444928188101039


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 41: EEG2EEG loss: 1.8060393206616665, Corr : 0.5471123285065512


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 42: EEG2EEG loss: 1.8000907339948289, Corr : 0.5460508131114311


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 43: EEG2EEG loss: 1.8006237402875374, Corr : 0.5447700051376335


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 44: EEG2EEG loss: 1.7935952709076253, Corr : 0.5395607187612057


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 45: EEG2EEG loss: 1.8004874739241092, Corr : 0.5369917386214762


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 46: EEG2EEG loss: 1.798470432453967, Corr : 0.5460984289357242


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 47: EEG2EEG loss: 1.7940595682631149, Corr : 0.5356289250358874


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 48: EEG2EEG loss: 1.7984790611774364, Corr : 0.5547861660076252


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 49: EEG2EEG loss: 1.7883642328546403, Corr : 0.5496218063792793


  0%|          | 0/94 [00:00<?, ?it/s]

Round 1 using 3000 trials, Sub01 -> Sub02: Epoch 50: EEG2EEG loss: 1.7896467586781115, Corr : 0.5443164402839503


  0%|          | 0/188 [00:00<?, ?it/s]

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 6.00 GiB total capacity; 1.14 GiB already allocated; 0 bytes free; 1.66 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF