In [11]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.model_selection import train_test_split
import os.path as osp
from tqdm import tqdm
from PIL import Image

In [12]:
class LinearAutoEncoder(torch.nn.Module):
    def __init__(self, features: int, hidden: int, out_features:int):
        super().__init__()
        self.layers = nn.Sequential()
        inp = features
        while inp // 4 > out_features:
            self.layers.append(nn.Linear(inp, inp//4))
            self.layers.append(nn.ReLU())
            inp = inp // 4
            
        self.layers.append(nn.Linear(inp, hidden))
        self.relu = nn.ReLU()
        self.decoder = nn.Linear(hidden, out_features)

    def forward(self, x):
        x = self.layers(x)
        x = self.relu(x)
        return self.decoder(x)

In [13]:
file = '/raid/data/cats_dogs_dataset/PetImages/Cat/758.jpg'

In [14]:
data = np.array(Image.open(file).convert('L').resize((256, 256), Image.ANTIALIAS))
t = split_tensor(data)

NameError: name 'split_tensor' is not defined

In [15]:
256/8

32.0

In [16]:
##Create ModuleList
def create_layers(img_shape, split_coefs=None, hidden=16, out=10, tp = 'linear', isoclines=16):
    layers = nn.ModuleList()

    if split_coefs is None:
        split_coefs = [1, 2, 4]

    for coef in split_coefs:
        #Add simular pieces
        input_layers = (img_shape ** 2) // (16 * coef ** 2)
        layers.append(LinearAutoEncoder(input_layers, hidden, out))

    return layers

In [17]:
idx_list = [np.random.randint(1, 10) for _ in range(4)]
line = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
res = np.delete(line, idx_list)
print(idx_list, len(line), len(res))
res

[2, 8, 3, 7] 10 6


array([0, 1, 4, 5, 6, 9])

In [18]:
import random
def create_perc_split(img, amount_of_steps):
    previous = np.percentile(img, 100)
    draw_img = np.zeros_like(img)
    step = 100 / amount_of_steps
    for itt_num in range(0, amount_of_steps):
        border = np.percentile(img, 100 - step * (itt_num + 1))
        result = np.where((previous >= img) & (img > border), itt_num, 0)
        previous = border
        draw_img += result
    return draw_img

def prepare_data(img, mask, amount_of_steps):
    h, w = img.shape
    results = []
    pixel_per_isocline = h * w // amount_of_steps
    
    for i in range(amount_of_steps):
        line = img[mask == i]
        if line.shape[0] != pixel_per_isocline:
            idx_list = random.sample(range(1, line.shape[0] - 1), abs(pixel_per_isocline - line.shape[0]))
#             [np.random.randint(1, line.shape[0] - 1) for _ in range(abs(pixel_per_isocline - line.shape[0]))]
            is_insert = True if line.shape[0] < pixel_per_isocline else False
            
            if not is_insert:
                line = np.delete(line, idx_list)
            else:
                mean_val = line.mean()
                line = np.insert(line, idx_list, mean_val)
        results.append(line)
    return results

# Get fourier batches 
def split_tensor(imgs, split_coefs=None, amount_of_isoc=16):
    if split_coefs is None:
        split_coefs = [1, 2, 4, 8]
    result = []
    result_ph = []

    *_, h, w = imgs.shape
    last_val = 0
    for coef in split_coefs:
        step = h // coef
        row_data = []
        for y in range(0, w, step):
            step_result = []
            for x in range(0, h, step):
                #temperory unused
#                value = torch.log(1 + torch.abs(torch.fft.fftshift(torch.fft.fft2(imgs[..., x:x + step, y:y + step]))))
                value = np.abs(np.fft.fftshift(np.fft.fft2(imgs[..., x:x + step, y:y + step])))
                phase = np.angle(np.fft.fftshift(np.fft.fft2(imgs[..., x:x + step, y:y + step])))

                value -= value.min()
                value /= value.max()

                phase -= phase.min()
                phase /= phase.max()

                mask = create_perc_split(value, amount_of_steps=amount_of_isoc)
                data = prepare_data(value, mask, amount_of_steps=amount_of_isoc)
                data_ph = prepare_data(phase, mask, amount_of_steps=amount_of_isoc)


                result.append(data)
                result_ph.append(data_ph)
    return result, result_ph


In [24]:
#CAT - 0, DOG - 1
import os
from pathlib import Path
import glob
import numpy as np
from PIL import Image
bad_files = ['/raid/data/cats_dogs_dataset/PetImages/Cat/666.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/Thumbs.db',
            '/raid/data/cats_dogs_dataset/PetImages/Dog/Thumbs.db', '/raid/data/cats_dogs_dataset/PetImages/Dog/11702.jpg',]
def dataset_preprocessing(directory, type_list = [('Cat', 0) , ('Dog', 1)], 
                          img_shape=(256,256), target_dirname='preprocessed_8'):
    result_data = []
    if not os.path.exists(f'/raid/data/cats_dogs_dataset/{target_dirname}'):
        os.mkdir(f'/raid/data/cats_dogs_dataset/{target_dirname}')
    
    for dir_name, target_type in type_list:
        for file in tqdm(glob.glob(f'{directory}/{dir_name}/*')):
            if file in bad_files:
                continue
            data = np.array(Image.open(file).convert('L').resize(img_shape, Image.ANTIALIAS))
            try:
                data_amp, data_phase = split_tensor(data)
            except:
                bad_files.append(file)
                continue
            if not os.path.exists(f'/raid/data/cats_dogs_dataset/{target_dirname}/{dir_name}'):
                os.mkdir(f'/raid/data/cats_dogs_dataset/{target_dirname}/{dir_name}')
            np.save(f'/raid/data/cats_dogs_dataset/{target_dirname}/{dir_name}/{Path(file).stem}.npy', 
                    [data_amp, data_phase, target_type])

In [25]:
224 * 224 / (16 * 16)

196.0

In [26]:
bad_files

['/raid/data/cats_dogs_dataset/PetImages/Cat/666.jpg',
 '/raid/data/cats_dogs_dataset/PetImages/Cat/Thumbs.db',
 '/raid/data/cats_dogs_dataset/PetImages/Dog/Thumbs.db',
 '/raid/data/cats_dogs_dataset/PetImages/Dog/11702.jpg']

In [27]:
drop_names = ['/raid/data/cats_dogs_dataset/PetImages/Cat/758.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/3130.jpg', 
             '/raid/data/cats_dogs_dataset/PetImages/Cat/10151.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/3216.jpg',
             '/raid/data/cats_dogs_dataset/PetImages/Cat/11819.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/11928.jpg',
             '/raid/data/cats_dogs_dataset/PetImages/Cat/8811.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/947.jpg', 
             '/raid/data/cats_dogs_dataset/PetImages/Cat/6145.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/8456.jpg', 
             '/raid/data/cats_dogs_dataset/PetImages/Cat/4522.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/7469.jpg',
             '/raid/data/cats_dogs_dataset/PetImages/Cat/2771.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/11184.jpg', 
             '/raid/data/cats_dogs_dataset/PetImages/Cat/7396.jpg', '/raid/data/cats_dogs_dataset/PetImages/Cat/9765.jpg', 
             ]

In [None]:
dataset_preprocessing('/raid/data/cats_dogs_dataset/PetImages')

  arr = np.asanyarray(arr)
  phase /= phase.max()
  value /= value.max()
 17%|███████▊                                       | 2078/12501 [13:53<1:10:06,  2.48it/s]

In [102]:
def get_embedding(data, layers, iso_amount = 16, batch_size=2):
    node_embeddings = []
    min_value = 0
    
    it_val = 0
    for it_val, value in enumerate(data):
        value = value.reshape(batch_size, iso_amount, -1)
        if it_val == 0:
            layer_id = 0
        elif 1 <= it_val <= 4:
            layer_id = 1
        elif 5 <= it_val <= 21:
            layer_id = 2
        for i in range(iso_amount):
            node_embeddings.append(layers[layer_id](value[:, i, :]))
        
    return torch.stack(node_embeddings, dim=1)

In [90]:
def get_embedding(data, layers, iso_amount = 16):
    result = []
    for batch in data:
        min_value = 0
        node_embeddings = []
        it_val = 0
        for it_val, value in enumerate(batch):
            if it_val == 0:
                layer_id = 0
            elif 1 <= it_val <= 4:
                layer_id = 1
            elif 5 <= it_val <= 21:
                layer_id = 2
            for i in range(iso_amount):
                node_embeddings.append(layers[layer_id](value[i]))
                
        result.append(torch.stack(node_embeddings))
    return torch.stack(result)

In [103]:
SIZE = 256
NEURAL_TYPE = 'linear'
import torch.nn as nn
from torch_geometric.nn import GCNConv,SAGEConv
from torch_geometric.nn import global_mean_pool



class GCN(torch.nn.Module):
    def __init__(self, hidden_channels, num_classes, size=256, out=10, emb_type = 'linear'):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.neural_type = emb_type
        self.layers = create_layers(size, tp=NEURAL_TYPE, out=out)

        self.conv1 = SAGEConv(out, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)
        self.soft = nn.Softmax(dim=1)


    def forward(self, x, edge_index, batch, batch_size):
        x = get_embedding(x, self.layers, iso_amount=16, batch_size=batch_size) # current linear
        #Maybe add this in get_embedding?
        x = x.view(batch_size * 336, -1)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]
        #x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)
        return self.soft(x)


In [None]:
17493

In [104]:
from torch_geometric.data import Dataset, download_url
from torch_geometric.data import Data

class MyOwnDataset(Dataset):
    def __init__(self, root, files_list, is_train, size=256, allow_loops=False, transform=None, pre_transform=None,
                 pre_filter=None, iso_amount=16, split_amount=21):
        self.data = files_list
        self.allow_loops = allow_loops
        self.is_train = is_train
        # TODO fix this shit
        self.iso_amount = iso_amount
        self.split_amount = split_amount
        self.gl_count = 1750 if self.is_train else 750
        super().__init__(root, transform, pre_transform, pre_filter)

    @property
    def raw_file_names(self):
        return []

    @property
    def num_nodes(self):
        return self.iso_amount * self.split_amount

    @property
    def processed_file_names(self):
        return [f'data_{idx}_is_train_{self.is_train}_loops={self.allow_loops}.pt' for idx in range(self.gl_count)]

    def _create_cco_matrix(self):
        vert_amount = int(self.iso_amount * self.split_amount)

        adj_matrix = np.zeros((vert_amount, vert_amount))
        for i in range(vert_amount - 1):
            if self.allow_loops:
                adj_matrix[i][i] = 1
            if i % self.iso_amount != 1:
                adj_matrix[i][i+1] = 1
            for j in range(self.split_amount):
                if self.iso_amount * j + i < vert_amount:
                    adj_matrix[i][self.iso_amount * j + i] = 1
                else:
                    break


        source_nodes = []
        target_nodes = []
        edge_list = []
        for iy, ix in np.ndindex(adj_matrix.shape):
            if adj_matrix[iy, ix] == 1:
                source_nodes.append(ix)
                target_nodes.append(iy)

                # unweighted solution
                edge_list.append(1)

        return source_nodes, target_nodes, edge_list

    def process(self):
        idx = 0
        source_vert, target_vert, edge_list = self._create_cco_matrix()
        edge_idx = torch.tensor([source_vert, target_vert])
        # DEBUG ROW
        _, small_data = train_test_split(self.data, test_size=0.1, random_state=42)
        train_data, test_data = train_test_split(small_data, test_size=0.3, random_state=42)
        data = train_data if self.is_train else test_data

        for file in data:
            # Read data from `raw_path`.
            amp, phase, target = np.load(file, allow_pickle=True)
            amp = [torch.tensor(item).float() for item in amp]

            #lst = [torch.from_numpy(item).float() for item in lst]

            # temporary only amp|
            data = Data(x=amp,
                        # edge_index=torch.tensor(edge_idx).clone().detach().float().requires_grad_(True),
                        edge_index=edge_idx,
                        edge_attrs=edge_list,
                        y=torch.tensor([target]))

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}_is_train_{self.is_train}_loops={self.allow_loops}.pt'))
            idx += 1
        self.gl_count = idx

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}_is_train_{self.is_train}_loops={self.allow_loops}.pt'))
        return data


In [105]:
import glob

In [112]:
save_root = '/raid/data/cats_dogs_dataset/'
files = glob.glob('/raid/data/cats_dogs_dataset/preprocessed/*/*.npy', recursive=True)
train_dataset = MyOwnDataset(save_root, files, is_train = True)
test_dataset = MyOwnDataset(save_root, files, is_train = False)

from torch_geometric.loader import DataLoader

device = torch.device('cuda:0')
save_root = '/raid/data/cats_dogs_dataset/'
files = glob.glob('/raid/data/cats_dogs_dataset/preprocessed/*/*.npy', recursive=True)
train_dataset = MyOwnDataset(save_root, files, is_train = True)
test_dataset = MyOwnDataset(save_root, files, is_train = False)

model = GCN(num_classes=2, hidden_channels = 10).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 0.001, weight_decay = 0.001)
criterion = torch.nn.CrossEntropyLoss()

batch_size = 20
train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False, drop_last=True)

In [114]:
# weights_stem = 'GNN_Linear_No_loops_bsize=1'
# wandb.run.name = weights_stem


#                 experiment.log_metric("train_dice_loss", batch_loss.item(),
#                                       epoch=epoch_idx, step=step_counter[action])
def train():
    model.train()
    itt = 0
    for itt, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
#         out = model(data.x, data.edge_index, data.batch) 
        optimizer.zero_grad()  # Clear gradients.

        data = data.to(device)
#          print(data.x.shape)
        
        #
        out = model(data.x, data.edge_index, data.batch, batch_size)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        if itt % 100 == 0:
            print(loss.item())
#             wandb.log({"train_cross_entropy_loss": loss.item()})
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

def test(loader):
    model.eval()
    correct = 0
    for itt, data in enumerate(loader):  # %Iterate in batches over the training/test dataset.
        data = data.to(device)
        out = model(data.x, data.edge_index, data.batch, batch_size) # Perform a single forward pass.
        data = data.to(device)
        pred = out.argmax(dim=1)# Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
#         correct += int((pred == data.y>.squeeze().unsqueeze(0).float()).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.

best_train, cur_epoch, best_val = -1, -1, -1 
for epoch in tqdm(range(1, 50)):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    best_train, cur_epoch, best_val = (train_acc, epoch, test_acc) if test_acc > best_val \
                                    else (best_train, cur_epoch, best_val)
    #wandb.log({'best train accuracy': best_train})
    #wandb.log({'best test accuracy': best_val})
    
    #wandb.log({'current train accuracy': train_acc})
    #wandb.log({'current test accuracy': test_acc})

    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

print(f'Best test accuracy - {best_val} on epoch {cur_epoch} with train accuracy -ч> {best_train}')

  0%|                                                    | 0/49 [00:00<?, ?it/s]

0.6948073506355286


  2%|▊                                        | 1/49 [01:50<1:28:32, 110.67s/it]

Epoch: 001, Train Acc: 0.5217, Test Acc: 0.4920
0.6699774265289307


  4%|█▋                                       | 2/49 [03:41<1:26:52, 110.91s/it]

Epoch: 002, Train Acc: 0.5229, Test Acc: 0.4920
0.6886707544326782


  6%|██▌                                      | 3/49 [05:32<1:25:08, 111.05s/it]

Epoch: 003, Train Acc: 0.5223, Test Acc: 0.4920
0.6941331028938293


  8%|███▎                                     | 4/49 [07:24<1:23:17, 111.06s/it]

Epoch: 004, Train Acc: 0.5229, Test Acc: 0.4920
0.6746267080307007


 10%|████▏                                    | 5/49 [09:13<1:21:07, 110.63s/it]

Epoch: 005, Train Acc: 0.5394, Test Acc: 0.5187
0.6726589798927307


 12%|█████                                    | 6/49 [11:04<1:19:18, 110.66s/it]

Epoch: 006, Train Acc: 0.5737, Test Acc: 0.5347
0.722974419593811


 14%|█████▊                                   | 7/49 [12:55<1:17:35, 110.85s/it]

Epoch: 007, Train Acc: 0.5829, Test Acc: 0.5493
0.7246370315551758


 16%|██████▋                                  | 8/49 [14:47<1:15:56, 111.12s/it]

Epoch: 008, Train Acc: 0.5971, Test Acc: 0.5533
0.7205384969711304


 18%|███████▌                                 | 9/49 [16:38<1:14:06, 111.17s/it]

Epoch: 009, Train Acc: 0.6069, Test Acc: 0.5627
0.7775242328643799


 20%|████████▏                               | 10/49 [18:29<1:12:13, 111.11s/it]

Epoch: 010, Train Acc: 0.6057, Test Acc: 0.5640
0.5820300579071045


 22%|████████▉                               | 11/49 [20:20<1:10:22, 111.12s/it]

Epoch: 011, Train Acc: 0.6023, Test Acc: 0.5707
0.6598140001296997


 24%|█████████▊                              | 12/49 [22:11<1:08:30, 111.08s/it]

Epoch: 012, Train Acc: 0.6080, Test Acc: 0.5707
0.6984009146690369


 27%|██████████▌                             | 13/49 [24:01<1:06:22, 110.63s/it]

Epoch: 013, Train Acc: 0.6103, Test Acc: 0.5747
0.5728474855422974


 29%|███████████▍                            | 14/49 [25:51<1:04:21, 110.33s/it]

Epoch: 014, Train Acc: 0.6320, Test Acc: 0.5667
0.6139127016067505


 31%|████████████▏                           | 15/49 [27:40<1:02:20, 110.02s/it]

Epoch: 015, Train Acc: 0.6377, Test Acc: 0.5640
0.6764341592788696


 33%|█████████████                           | 16/49 [29:29<1:00:25, 109.86s/it]

Epoch: 016, Train Acc: 0.6429, Test Acc: 0.5853
0.5836280584335327


 35%|██████████████▌                           | 17/49 [31:19<58:30, 109.69s/it]

Epoch: 017, Train Acc: 0.6383, Test Acc: 0.5707
0.6432305574417114


 37%|███████████████▍                          | 18/49 [33:08<56:32, 109.43s/it]

Epoch: 018, Train Acc: 0.6486, Test Acc: 0.5653
0.6612120866775513


 39%|████████████████▎                         | 19/49 [34:57<54:38, 109.27s/it]

Epoch: 019, Train Acc: 0.6383, Test Acc: 0.5747
0.5992501974105835


 41%|█████████████████▏                        | 20/49 [36:46<52:47, 109.23s/it]

Epoch: 020, Train Acc: 0.6611, Test Acc: 0.5720
0.6421306729316711


 43%|██████████████████                        | 21/49 [38:35<50:57, 109.19s/it]

Epoch: 021, Train Acc: 0.6640, Test Acc: 0.5760
0.701310932636261


 45%|██████████████████▊                       | 22/49 [40:24<49:07, 109.16s/it]

Epoch: 022, Train Acc: 0.6509, Test Acc: 0.5760
0.5510022044181824


 47%|███████████████████▋                      | 23/49 [42:12<47:14, 109.01s/it]

Epoch: 023, Train Acc: 0.6691, Test Acc: 0.5773
0.7113203406333923


 49%|████████████████████▌                     | 24/49 [44:02<45:25, 109.02s/it]

Epoch: 024, Train Acc: 0.6789, Test Acc: 0.5613
0.544886589050293


 51%|█████████████████████▍                    | 25/49 [45:51<43:36, 109.02s/it]

Epoch: 025, Train Acc: 0.7126, Test Acc: 0.5840
0.6499166488647461


 53%|██████████████████████▎                   | 26/49 [47:40<41:47, 109.02s/it]

Epoch: 026, Train Acc: 0.7149, Test Acc: 0.5813
0.5749301314353943


 55%|███████████████████████▏                  | 27/49 [49:29<39:59, 109.05s/it]

Epoch: 027, Train Acc: 0.6823, Test Acc: 0.5680
0.698563277721405


 57%|████████████████████████                  | 28/49 [51:18<38:09, 109.04s/it]

Epoch: 028, Train Acc: 0.7429, Test Acc: 0.5853
0.5486679077148438


 59%|████████████████████████▊                 | 29/49 [53:07<36:20, 109.01s/it]

Epoch: 029, Train Acc: 0.7434, Test Acc: 0.5720
0.5777146220207214


 61%|█████████████████████████▋                | 30/49 [54:56<34:32, 109.09s/it]

Epoch: 030, Train Acc: 0.7663, Test Acc: 0.5720
0.6211085319519043


 63%|██████████████████████████▌               | 31/49 [56:45<32:43, 109.10s/it]

Epoch: 031, Train Acc: 0.8029, Test Acc: 0.5787
0.5515869855880737


 65%|███████████████████████████▍              | 32/49 [58:34<30:55, 109.13s/it]

Epoch: 032, Train Acc: 0.7891, Test Acc: 0.5693
0.4350293278694153


 67%|██████████████████████████▉             | 33/49 [1:00:23<29:05, 109.12s/it]

Epoch: 033, Train Acc: 0.8240, Test Acc: 0.5800
0.4819869101047516


 69%|███████████████████████████▊            | 34/49 [1:02:12<27:15, 109.03s/it]

Epoch: 034, Train Acc: 0.8080, Test Acc: 0.5600
0.6432255506515503


 71%|████████████████████████████▌           | 35/49 [1:04:01<25:25, 108.99s/it]

Epoch: 035, Train Acc: 0.8451, Test Acc: 0.5467
0.33555835485458374


 73%|█████████████████████████████▍          | 36/49 [1:05:50<23:36, 108.97s/it]

Epoch: 036, Train Acc: 0.8606, Test Acc: 0.5720
0.33968883752822876


 76%|██████████████████████████████▏         | 37/49 [1:07:39<21:47, 108.98s/it]

Epoch: 037, Train Acc: 0.8863, Test Acc: 0.5600
0.3199881911277771


 78%|███████████████████████████████         | 38/49 [1:09:28<19:58, 108.97s/it]

Epoch: 038, Train Acc: 0.8926, Test Acc: 0.5547
0.5280546545982361


 80%|███████████████████████████████▊        | 39/49 [1:11:17<18:09, 108.96s/it]

Epoch: 039, Train Acc: 0.9040, Test Acc: 0.5600
0.3717808723449707


 82%|████████████████████████████████▋       | 40/49 [1:13:06<16:20, 108.89s/it]

Epoch: 040, Train Acc: 0.8686, Test Acc: 0.5600
0.3677230179309845


 84%|█████████████████████████████████▍      | 41/49 [1:14:54<14:30, 108.84s/it]

Epoch: 041, Train Acc: 0.8949, Test Acc: 0.5613
0.3142542839050293


 86%|██████████████████████████████████▎     | 42/49 [1:16:43<12:41, 108.78s/it]

Epoch: 042, Train Acc: 0.9097, Test Acc: 0.5667
0.375369131565094


 86%|██████████████████████████████████▎     | 42/49 [1:17:01<12:50, 110.05s/it]


KeyboardInterrupt: 

In [86]:
def get_embedding_test(data, layers, iso_amount=16, batch_size=2):
    result, node_embeddings = [], []
    min_value = 0
    
    it_val = 0
    for it_val, value in enumerate(data):
        print(value.shape)
        value = value.reshape(batch_size, iso_amount, -1)
        if it_val == 0:
            layer_id = 0
        elif 1 <= it_val <= 4:
            layer_id = 1
        elif 5 <= it_val <= 21:
            layer_id = 2
        for i in range(iso_amount):
            print(value.shape)
            node_embeddings.append(layers[layer_id](value[:, i, :]))
        
    return torch.stack(node_embeddings, dim=1)

In [88]:
for itt, data in enumerate(train_loader):  # Iterate in batches over the training dataset.
    z = get_embedding(data.x, model.layers, iso_amount=16, batch_size=2)
    print(z.shape)
    break

TypeError: get_embedding() got an unexpected keyword argument 'batch_size'