In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import spearmanr
import warnings
import copy
import os
warnings.filterwarnings('ignore')

In [2]:
for subid1 in range(10):
    for subid2 in range(10):
        if subid1 != subid2:
            path = f'weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels'
            if not os.path.exists(path):
                os.makedirs(path)

In [3]:
def show(condition, fake, real):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    condition = condition.detach().cpu().numpy()
    fake = fake.detach().cpu().numpy()
    real = real.detach().cpu().numpy()
    t = np.arange(0, 200, 1)
    fig, axs = plt.subplots(2, 1)
    axs[0].plot(t, condition[0, 0])
    
    axs[1].plot(t, real[0, 0], label='real')
    axs[1].plot(t, fake[0, 0], label='fake')
    axs[1].legend()
    plt.show()

2D U-Net

In [4]:
def crop(data, new_shape):
    '''
    Function for cropping an image tensor: Given an image tensor and the new shape,
    crops to the center pixels.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width)
        new_shape: a torch.Size object with the shape you want x to have
    '''
    middle_length = data.shape[2] // 2
    starting = middle_length - round(new_shape[2] / 2)
    final = starting + new_shape[2]
    cropped_data = data[:, :, starting:final]
    return cropped_data

class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Performs two convolutions followed by a max pool operation.
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ContractingBlock, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, input_channels * 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(input_channels * 2, input_channels * 2, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        if use_bn:
            self.batchnorm = nn.BatchNorm1d(input_channels * 2)
        self.use_bn = use_bn
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x):
        '''
        Function for completing a forward pass of ContractingBlock: 
        Given an image tensor, completes a contracting block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv1(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x

class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class:
    Performs an upsampling, a convolution, a concatenation of its two inputs,
    followed by two more convolutions with optional dropout
    Values:
        input_channels: the number of channels to expect from a given input
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ExpandingBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv1d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(input_channels // 2, input_channels // 2, kernel_size=3, padding=1)
        if use_bn:
            self.batchnorm = nn.BatchNorm1d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def forward(self, x, skip_con_x):
        '''
        Function for completing a forward pass of ExpandingBlock: 
        Given an image tensor, completes an expanding block and returns the transformed tensor.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: the image tensor from the contracting path (from the opposing block of x)
                    for the skip connection
        '''
        x = self.upsample(x)
        x = self.conv1(x)
        #skip_con_x = crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], axis=1)
        x = self.conv2(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        x = self.conv3(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        x = self.activation(x)
        return x


class FeatureMapBlock0(nn.Module):
    '''
    FeatureMapBlock Class
    The first layers of a U-Net - 
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock0, self).__init__()
        self.conv1 = nn.Conv1d(input_channels, output_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(output_channels, output_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.batchnorm = nn.BatchNorm1d(output_channels)

    def forward(self, x):
        '''
        Function for completing a forward pass of FeatureMapBlock: 
        Given an image tensor, returns it mapped to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv1(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.batchnorm(x)
        x = self.activation(x)
        return x   
    
class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    The final layer of a U-Net - 
    maps each pixel to a pixel with the correct number of output dimensions
    using a 1x1 convolution.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv1d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Function for completing a forward pass of FeatureMapBlock: 
        Given an image tensor, returns it mapped to the desired number of channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x = self.conv(x)
        return x

class UNet(nn.Module):
    '''
    UNet Class
    A series of 4 contracting blocks followed by 4 expanding blocks to 
    transform an input image into the corresponding paired image, with an upfeature
    layer at the start and a downfeature layer at the end.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super(UNet, self).__init__()
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.contract1 = FeatureMapBlock0(input_channels, hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels, use_dropout=True)
        self.contract3 = ContractingBlock(hidden_channels * 2, use_dropout=True)
        self.contract4 = ContractingBlock(hidden_channels * 4)
        self.expand1 = ExpandingBlock(hidden_channels * 8)
        self.expand2 = ExpandingBlock(hidden_channels * 4)
        self.expand3 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        Function for completing a forward pass of UNet: 
        Given an image tensor, passes it through U-Net and returns the output.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        x1 = self.contract1(x)
        x2 = self.maxpool(x1)
        x3 = self.contract2(x2)
        x4 = self.maxpool(x3)
        x5 = self.contract3(x4)
        x6 = self.maxpool(x5)
        x7 = self.contract4(x6)
        x8 = self.expand1(x7, x5)
        x9 = self.expand2(x8, x3)
        x10 = self.expand3(x9, x1)
        xn = self.downfeature(x10)
        #return self.sigmoid(xn)
        return xn

In [5]:
import torch.nn.functional as F
# New parameters
recon_criterion = nn.MSELoss()

n_epochs = 100
input_dim = 17
real_dim = 17
display_step = 1000
batch_size = 32
lr = 0.002
#target_shape = 100
device = 'cuda'

import data

In [6]:
class GetData(torch.utils.data.Dataset):
    
    def __init__(self, data1, data2):
        self.data1 = data1
        self.data2 = data2
    
    def __getitem__(self, index):
        data1 = self.data1[index]
        data2 = self.data2[index]
        return data1, data2
    
    def __len__(self):
        return len(self.data1)

In [7]:
gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)

def weights_init(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.ConvTranspose1d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm1d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

gen = gen.apply(weights_init)

In [8]:
# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)
# GRADED CLASS: get_gen_loss
def get_gen_loss(gen, real, condition, recon_criterion):
    fake = gen(condition)
    return recon_criterion(real, fake)

In [13]:
def train_and_test(torch_data, input_dim, real_dim, lr, testsub1data, testsub2data, subid1, subid2):

    gen = UNet(input_dim, real_dim).to(device)
    gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
    gen = gen.apply(weights_init)
    
    epochs_loss = np.zeros([n_epochs])
    epochs_corr = np.zeros([n_epochs])
    
    dataloader = DataLoader(torch_data, batch_size=batch_size, shuffle=True)
    
    best_corr = 0
    
    gen.train()

    for epoch in range(n_epochs):
        loss = 0
        # Dataloader returns the batches
        cur_step = 0
        for data1, data2 in tqdm(dataloader):
            condition = data1
            real = data2
            cur_batch_size = len(condition)
            condition = condition.to(device)
            real = real.to(device)

            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, real, condition, recon_criterion)
            gen_loss.backward() # Update gradients
            gen_opt.step() # Update optimizer
            loss += gen_loss.item()
            
            cur_step += 1
        
        gen.eval()
        gen_opt.zero_grad()
        
        testsub1data = testsub1data.to(device)
        testsub2data = testsub2data.to(device)
        
        testsub2fakedata = gen(testsub1data)
        
        arr1 = testsub2fakedata.detach().cpu().numpy()
        arr2 = testsub2data.detach().cpu().numpy()
        
        corr = spearmanr(arr1.flatten(), arr2.flatten())[0]
            
        # Keep track of the average generator loss
        mean_loss = loss / cur_step
        
        print('Sub' + str(subid1+1).zfill(2) + ' -> ' + 'Sub' + str(subid2+1).zfill(2) + ': ' + f"Epoch {epoch+1}: EEG2EEG loss: {mean_loss}, Corr : {corr}")
        #show(condition, fake, real)
        loss = 0
        epochs_loss[epoch] = mean_loss
        epochs_corr[epoch] = corr
        if corr > best_corr:
            best_corr = corr
            best_gen = copy.deepcopy(gen.state_dict())
            best_gen_opt = copy.deepcopy(gen_opt.state_dict())
    torch.save({'gen':  gen.state_dict(),
                'gen_opt': gen_opt.state_dict()
                }, f"weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels/final_model.pth")
    torch.save({'gen':  best_gen,
                'gen_opt': best_gen_opt
                }, f"weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels/best_model.pth")
    np.save(f'weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels/gloss.npy', epochs_loss)
    np.save(f'weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels/corr.npy', epochs_corr)

for subid1 in range(10):
    for subid2 in range(10):
        if subid1 < subid2 and subid1 != 8 and subid2 != 8 and subid1 == 6 and subid2 == 7:
            data1 = np.load('eeg_data/train/sub' + str(subid1+1).zfill(2) + '.npy')
            mean1 = np.average(data1)
            std1 = np.std(data1)
            data1 = (data1-mean1)/std1
            data1 = np.transpose(data1, (1, 0, 2))
            data1 = torch.from_numpy(data1).float()
            data2 = np.load('eeg_data/train/sub' + str(subid2+1).zfill(2) + '.npy')
            mean2 = np.average(data2)
            std2 = np.std(data2)
            data2 = (data2-mean2)/std2
            data2 = np.transpose(data2, (1, 0, 2))
            data2 = torch.from_numpy(data2).float()
            torch_data = GetData(data1, data2)
            testsub1data = np.load('eeg_data/test/sub' + str(subid1+1).zfill(2) + '.npy')
            testsub1data = (testsub1data-mean1)/std1
            testsub1data = np.transpose(testsub1data, (1, 0, 2))
            testsub1data = torch.from_numpy(testsub1data).float()
            testsub2data = np.load('eeg_data/test/sub' + str(subid2+1).zfill(2) + '.npy')
            testsub2data = (testsub2data-mean2)/std2
            testsub2data = np.transpose(testsub2data, (1, 0, 2))
            testsub2data = torch.from_numpy(testsub2data).float()
            train_and_test(torch_data, input_dim, real_dim, lr, testsub1data, testsub2data, subid1, subid2)

  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 1: EEG2EEG loss: 0.9251175586213457, Corr : 0.3559150615325467


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 2: EEG2EEG loss: 0.932612998803764, Corr : 0.6659637427128599


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 3: EEG2EEG loss: 0.9165624638936506, Corr : 0.6755654666620189


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 4: EEG2EEG loss: 0.9132339848309928, Corr : 0.664270608420093


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 5: EEG2EEG loss: 0.9105598714411605, Corr : 0.6748279785411735


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 6: EEG2EEG loss: 0.908169828254434, Corr : 0.6828237234846083


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 7: EEG2EEG loss: 0.9073673391019352, Corr : 0.6867603416550553


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 8: EEG2EEG loss: 0.9058970019250128, Corr : 0.683210064166783


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 9: EEG2EEG loss: 0.9051836473107108, Corr : 0.6919687600432827


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 10: EEG2EEG loss: 0.9047376154930956, Corr : 0.6764126320979821


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 11: EEG2EEG loss: 0.9051853179701062, Corr : 0.6997838945567535


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 12: EEG2EEG loss: 0.9041193611847824, Corr : 0.6992774762693398


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 13: EEG2EEG loss: 0.9041448238970463, Corr : 0.696464581068474


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 14: EEG2EEG loss: 0.9031866018730621, Corr : 0.6866616994781021


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 15: EEG2EEG loss: 0.9032906973500316, Corr : 0.6930125692746975


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 16: EEG2EEG loss: 0.9035281640994479, Corr : 0.6985752471777675


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 17: EEG2EEG loss: 0.9036329492378973, Corr : 0.7010157951859523


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 18: EEG2EEG loss: 0.9024773290816773, Corr : 0.6983796774300691


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 19: EEG2EEG loss: 0.9023420097057778, Corr : 0.6934159563376069


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 20: EEG2EEG loss: 0.9013331895640094, Corr : 0.7006210165320176


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 21: EEG2EEG loss: 0.9020295093552979, Corr : 0.7025539670094895


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 22: EEG2EEG loss: 0.9012777230034023, Corr : 0.7001054808296777


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 23: EEG2EEG loss: 0.901124673954984, Corr : 0.6878575082448819


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 24: EEG2EEG loss: 0.9012341053158448, Corr : 0.7063353874517521


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 25: EEG2EEG loss: 0.9014715279094024, Corr : 0.7048917230055162


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 26: EEG2EEG loss: 0.9011505199470188, Corr : 0.6954085956111101


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 27: EEG2EEG loss: 0.9009805883153034, Corr : 0.7103058307042304


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 28: EEG2EEG loss: 0.9011442336169156, Corr : 0.7001860248470225


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 29: EEG2EEG loss: 0.9011346980842908, Corr : 0.706270559284249


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 30: EEG2EEG loss: 0.9007553198812548, Corr : 0.7027745323424738


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 31: EEG2EEG loss: 0.8999458945696773, Corr : 0.7019222591748986


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 32: EEG2EEG loss: 0.8998146954080815, Corr : 0.7033752070955445


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 33: EEG2EEG loss: 0.9001890218004267, Corr : 0.710484821921659


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 34: EEG2EEG loss: 0.9001307610847503, Corr : 0.7035444741463678


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 35: EEG2EEG loss: 0.900273003130863, Corr : 0.7024162243678544


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 36: EEG2EEG loss: 0.8990508655745472, Corr : 0.709553871469378


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 37: EEG2EEG loss: 0.8993355522303332, Corr : 0.6964440417451134


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 38: EEG2EEG loss: 0.8993058571271324, Corr : 0.7015033407844166


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 39: EEG2EEG loss: 0.8995622282332563, Corr : 0.7141306481508765


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 40: EEG2EEG loss: 0.8988117874821565, Corr : 0.7097580306085816


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 41: EEG2EEG loss: 0.8990288365279222, Corr : 0.7011719037455677


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 42: EEG2EEG loss: 0.8999267460299184, Corr : 0.7111128611672419


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 43: EEG2EEG loss: 0.8987837596373125, Corr : 0.719068605187741


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 44: EEG2EEG loss: 0.8985825499206263, Corr : 0.7022080054153299


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 45: EEG2EEG loss: 0.8984849714679459, Corr : 0.7074965528551226


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 46: EEG2EEG loss: 0.8982477004800129, Corr : 0.7166544925518247


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 47: EEG2EEG loss: 0.8981804973959231, Corr : 0.7195926182017646


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 48: EEG2EEG loss: 0.8980231646527636, Corr : 0.7111351552773832


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 49: EEG2EEG loss: 0.8981966252944916, Corr : 0.7102871534498965


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 50: EEG2EEG loss: 0.8982613367430247, Corr : 0.7075064221413309


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 51: EEG2EEG loss: 0.8977161515612205, Corr : 0.6771881911344011


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 52: EEG2EEG loss: 0.8974181304125537, Corr : 0.7172390061152565


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 53: EEG2EEG loss: 0.8972998292579872, Corr : 0.7203821460502623


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 54: EEG2EEG loss: 0.8971396131833702, Corr : 0.7076970497443901


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 55: EEG2EEG loss: 0.8965083169291513, Corr : 0.7163576081221431


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 56: EEG2EEG loss: 0.8974313746914412, Corr : 0.7124777824635433


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 57: EEG2EEG loss: 0.8966745870256332, Corr : 0.7099836491280532


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 58: EEG2EEG loss: 0.8973752873547994, Corr : 0.7107844006857086


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 59: EEG2EEG loss: 0.8965502666089705, Corr : 0.7199708008877765


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 60: EEG2EEG loss: 0.8963648961637191, Corr : 0.7223167588478026


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 61: EEG2EEG loss: 0.8964982725665464, Corr : 0.7234764500228644


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 62: EEG2EEG loss: 0.8967017831147525, Corr : 0.7105126203844551


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 63: EEG2EEG loss: 0.8964105470383421, Corr : 0.7092605294425818


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 64: EEG2EEG loss: 0.8965901703622401, Corr : 0.7189230430445771


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 65: EEG2EEG loss: 0.896148175519947, Corr : 0.7173954272806726


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 66: EEG2EEG loss: 0.8958788972989272, Corr : 0.7131546485540459


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 67: EEG2EEG loss: 0.8953462036712027, Corr : 0.7230169691701728


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 68: EEG2EEG loss: 0.8957064216326007, Corr : 0.7249545837996412


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 69: EEG2EEG loss: 0.8957740258663258, Corr : 0.7255312534496282


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 70: EEG2EEG loss: 0.8956210020086521, Corr : 0.7195490534930841


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 71: EEG2EEG loss: 0.8959216392478353, Corr : 0.7191151679265073


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 72: EEG2EEG loss: 0.894736898994538, Corr : 0.7116510629392659


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 73: EEG2EEG loss: 0.8949825050060708, Corr : 0.7232072125538922


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 74: EEG2EEG loss: 0.8948667508836411, Corr : 0.7186768577139514


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 75: EEG2EEG loss: 0.8949268386488496, Corr : 0.7126953053613606


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 76: EEG2EEG loss: 0.8945667069239589, Corr : 0.7171726010862006


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 77: EEG2EEG loss: 0.8950273048139864, Corr : 0.7213906537769637


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 78: EEG2EEG loss: 0.8938455515373607, Corr : 0.7119667234947313


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 79: EEG2EEG loss: 0.894052031768821, Corr : 0.7079314981611055


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 80: EEG2EEG loss: 0.8940318708147716, Corr : 0.7180088153888318


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 81: EEG2EEG loss: 0.8938322181397296, Corr : 0.7066139420361468


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 82: EEG2EEG loss: 0.8944329195718931, Corr : 0.706664685386649


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 83: EEG2EEG loss: 0.8937688494097564, Corr : 0.7087363828619465


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 84: EEG2EEG loss: 0.8936191718859645, Corr : 0.7149329105138242


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 85: EEG2EEG loss: 0.8942592242008482, Corr : 0.7141771949094103


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 86: EEG2EEG loss: 0.8932969625498848, Corr : 0.7148414528500429


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 87: EEG2EEG loss: 0.893167504950695, Corr : 0.7154540271066707


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 88: EEG2EEG loss: 0.8935598318304729, Corr : 0.717900849144678


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 89: EEG2EEG loss: 0.8929779266711592, Corr : 0.7129306597755004


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 90: EEG2EEG loss: 0.8928085804792621, Corr : 0.7079953927415796


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 91: EEG2EEG loss: 0.8921121078718101, Corr : 0.7148822456446915


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 92: EEG2EEG loss: 0.8925866049424361, Corr : 0.7128793272781531


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 93: EEG2EEG loss: 0.8930725386686897, Corr : 0.704027567653387


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 94: EEG2EEG loss: 0.8924950301301318, Corr : 0.7153692653151308


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 95: EEG2EEG loss: 0.8924382112242959, Corr : 0.7145097077754677


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 96: EEG2EEG loss: 0.8920229691831932, Corr : 0.7133919382577918


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 97: EEG2EEG loss: 0.8920965319214753, Corr : 0.7137511731403445


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 98: EEG2EEG loss: 0.8921664635032934, Corr : 0.7210899324013419


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 99: EEG2EEG loss: 0.8917474142010945, Corr : 0.7104873698599656


  0%|          | 0/517 [00:00<?, ?it/s]

Sub07 -> Sub08: Epoch 100: EEG2EEG loss: 0.8924008075227129, Corr : 0.7131009053359892


In [15]:
for subid1 in range(10):
    for subid2 in range(10):
        if subid1 != subid2 and subid1 != 8 and subid2 != 8 and subid1 == 6 and subid2 >= 7:
            data1 = np.load('eeg_data/train/sub' + str(subid1+1).zfill(2) + '.npy')
            mean1 = np.average(data1)
            std1 = np.std(data1)
            data1 = (data1-mean1)/std1
            data1 = np.transpose(data1, (1, 0, 2))
            data1 = torch.from_numpy(data1).float()
            data2 = np.load('eeg_data/train/sub' + str(subid2+1).zfill(2) + '.npy')
            mean2 = np.average(data2)
            std2 = np.std(data2)
            data2 = (data2-mean2)/std2
            data2 = np.transpose(data2, (1, 0, 2))
            data2 = torch.from_numpy(data2).float()
            torch_data = GetData(data1, data2)
            testsub1data = np.load('eeg_data/test/st_sub' + str(subid1+1).zfill(2) + '.npy')
            testsub1data = (testsub1data-mean1)/std1
            testsub1data = np.transpose(testsub1data, (1, 2, 0, 3))
            gen = UNet(input_dim, real_dim).to(device)
            gen.load_state_dict(torch.load(f"weights_nocosineloss/Sub{subid1+1}ToSub{subid2+1}_AllChannels/best_model.pth")['gen'])
            gen.eval()
            testsub2fakedata = np.zeros([200, 80, 17, 200])
            for i in range(200):
                test = torch.from_numpy(testsub1data[i]).float()
                test = test.to(device)
                testsub2fakedata[i] = gen(test).detach().cpu().numpy()
            np.save(f"generated_nocosineloss/st_Sub{subid1+1}ToSub{subid2+1}.npy", testsub2fakedata)