In [None]:
import os
import urllib.request
import zipfile
import nibabel as nib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from torchvision import transforms

from fastai.vision.all import *

from google.colab import drive

from tqdm.notebook import tqdm

#Brain Activity Patterns in Learning Processes

a nerual network to classify different stages of learning (e.g., early, middle, late) based on fMRI data. This could provide insights into how the brain's activity changes as we acquire new skills or knowledge.

dataset: [Classification learning](https://openfmri.org/dataset/ds000002/) from OpenfMRI

Download and extract the dataset

In [None]:
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
base_path = '/content/drive/MyDrive/learnedSpectrum'
zip_path = os.path.join(base_path, "ds000002_R2.0.5_raw.zip")
extract_path = os.path.join(base_path, "fmri_data")

In [None]:
url = "https://s3.amazonaws.com/openneuro/ds000002/ds000002_R2.0.5/compressed/ds000002_R2.0.5_raw.zip"

In [None]:
print("Downloading dataset...")
urllib.request.urlretrieve(url, zip_path)

Downloading dataset...


('/content/drive/MyDrive/learnedSpectrum/ds000002_R2.0.5_raw.zip',
 <http.client.HTTPMessage at 0x7cd24559feb0>)

In [None]:
print("Extracting dataset...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

Extracting dataset...


In [None]:
print("Dataset extracted to:", extract_path)

Dataset extracted to: /content/drive/MyDrive/learnedSpectrum/fmri_data


explore directory structure

In [None]:
def explore_directory(path, level=0):
    print("|   " * level + "+--" + os.path.basename(path))
    if os.path.isdir(path):
        for item in os.listdir(path):
            item_path = os.path.join(path, item)
            if os.path.isdir(item_path):
                explore_directory(item_path, level + 1)
            elif item.endswith('.nii.gz'):
                print("|   " * (level + 1) + "+--" + item)

In [None]:
print("\nDataset structure:")
explore_directory(extract_path)


Dataset structure:
+--fmri_data
|   +--ds002_R2.0.5
|   |   +--sub-03
|   |   |   +--anat
|   |   |   |   +--sub-03_T1w.nii.gz
|   |   |   |   +--sub-03_inplaneT2.nii.gz
|   |   |   +--func
|   |   |   |   +--sub-03_task-mixedeventrelatedprobe_run-01_bold.nii.gz
|   |   |   |   +--sub-03_task-probabilisticclassification_run-01_bold.nii.gz
|   |   |   |   +--sub-03_task-mixedeventrelatedprobe_run-02_bold.nii.gz
|   |   |   |   +--sub-03_task-deterministicclassification_run-01_bold.nii.gz
|   |   |   |   +--sub-03_task-deterministicclassification_run-02_bold.nii.gz
|   |   |   |   +--sub-03_task-probabilisticclassification_run-02_bold.nii.gz
|   |   +--sub-08
|   |   |   +--anat
|   |   |   |   +--sub-08_T1w.nii.gz
|   |   |   |   +--sub-08_inplaneT2.nii.gz
|   |   |   +--func
|   |   |   |   +--sub-08_task-mixedeventrelatedprobe_run-02_bold.nii.gz
|   |   |   |   +--sub-08_task-deterministicclassification_run-01_bold.nii.gz
|   |   |   |   +--sub-08_task-deterministicclassification_run

load and display basic info about NIfTI files

In [None]:
def display_nifti_info(file_path):
    img = nib.load(file_path)
    print(f"File: {os.path.basename(file_path)}")
    print(f"Shape: {img.shape}")
    print(f"Data type: {img.get_data_dtype()}")
    print(f"Header info:")
    print(img.header)
    print("\n")

Display info for a sample NIfTI file

In [None]:
sample_nifti = None
for root, dirs, files in os.walk(extract_path):
    for file in files:
        if file.endswith('.nii.gz'):
            sample_nifti = os.path.join(root, file)
            break
    if sample_nifti:
        break

if sample_nifti:
    print("\nSample NIfTI file info:")
    display_nifti_info(sample_nifti)
else:
    print("No NIfTI files found in the dataset.")


Sample NIfTI file info:
File: sub-03_T1w.nii.gz
Shape: (160, 192, 192)
Data type: int16
Header info:
<class 'nibabel.nifti1.Nifti1Header'> object, endian='<'
sizeof_hdr      : 348
data_type       : b''
db_name         : b''
extents         : 0
session_error   : 0
regular         : b''
dim_info        : 0
dim             : [  3 160 192 192   1   1   1   1]
intent_p1       : 0.0
intent_p2       : 0.0
intent_p3       : 0.0
intent_code     : none
datatype        : int16
bitpix          : 16
slice_start     : 0
pixdim          : [-1.         1.         1.3333333  1.3333333  0.         1.
  1.         1.       ]
vox_offset      : 0.0
scl_slope       : nan
scl_inter       : nan
slice_end       : 0
slice_code      : unknown
xyzt_units      : 10
cal_max         : 0.0
cal_min         : 0.0
slice_duration  : 0.0
toffset         : 0.0
glmax           : 0
glmin           : 0
descrip         : b'FreeSurfer Aug 11 2009'
aux_file        : b''
qform_code      : scanner
sform_code      : scanner
quater

Data augmentation

In [None]:
class RandomNoise(object):
    def __init__(self, mean=0, std=0.1):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

Define preprocessing function

In [None]:
def preprocess_fmri(img):
    try:
        print(f"Original shape: {img.shape}")

        if len(img.shape) != 4:
            raise ValueError(f"Expected 4D input, got {len(img.shape)}D")

        if img.shape[3] == 0:
            img = np.zeros(img.shape[:3] + (1,))

        mean = img.mean(axis=3, keepdims=True)
        std = img.std(axis=3, keepdims=True)
        img = (img - mean) / (std + 1e-8)

        img = img.mean(axis=3)

        target_shape = (64, 64, 64)
        current_shape = img.shape

        pad_width = []
        for t, c in zip(target_shape, current_shape):
            if t > c:
                pad_before = (t - c) // 2
                pad_after = t - c - pad_before
                pad_width.append((pad_before, pad_after))
            elif t < c:
                crop_before = (c - t) // 2
                crop_after = c - t - crop_before
                pad_width.append((-crop_before, -crop_after))
            else:
                pad_width.append((0, 0))

        img = np.pad(img, pad_width, mode='constant', constant_values=0)

        img = img[:64, :64, :64]

        img = img[np.newaxis, ...]

        img_tensor = torch.from_numpy(img).float()

        print(f"Processed shape: {img_tensor.shape}")
        return img_tensor

    except Exception as e:
        print(f"Error in preprocess_fmri: {str(e)}")
        return torch.zeros((1, 64, 64, 64))

Define the fMRIDataset class

In [None]:
class fMRIDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = os.path.join(root_dir, 'ds002_R2.0.5')
        self.samples = self._load_samples()

    def _load_samples(self):
        samples = []
        for subject in os.listdir(self.root_dir):
            subject_path = os.path.join(self.root_dir, subject)
            if os.path.isdir(subject_path) and subject.startswith('sub-'):
                func_path = os.path.join(subject_path, 'func')
                if os.path.exists(func_path):
                    nifti_files = [f for f in os.listdir(func_path) if f.endswith('_bold.nii.gz')]
                    for file in nifti_files:
                        file_path = os.path.join(func_path, file)
                        task = file.split('task-')[1].split('_')[0]
                        run = int(file.split('run-')[1].split('_')[0])

                        if 'deterministic' in task:
                            label = 0
                        elif 'probabilistic' in task:
                            label = 1
                        else:
                            label = 2

                        samples.append((file_path, label, task))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_path, label, task = self.samples[idx]
        try:
            img = nib.load(file_path).get_fdata()
            img = preprocess_fmri(img)
            return img, label, task
        except Exception as e:
            print(f"Error loading {file_path}: {str(e)}")
            # Return a dummy sample in case of error
            return torch.zeros((1, 64, 64, 64)), -1, "error"

In [None]:
root_dir = '/content/drive/MyDrive/learnedSpectrum/fmri_data'
dataset = fMRIDataset(root_dir=root_dir)

print(f"Dataset length: {len(dataset)}")

Dataset length: 102


In [None]:
for i in range(min(5, len(dataset))):
    try:
        img, label, task = dataset[i]
        print(f"Sample {i}:")
        print(f"  Image shape: {img.shape}")
        if img.numel() > 0:
            print(f"  Image min: {img.min().item():.4f}, max: {img.max().item():.4f}, mean: {img.mean().item():.4f}, std: {img.std().item():.4f}")
        else:
            print("  Image is empty")
        print(f"  Label: {label}")
        print(f"  Task: {task}")
    except Exception as e:
        print(f"Error loading sample {i}: {str(e)}")

Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Sample 0:
  Image shape: torch.Size([1, 64, 64, 64])
  Image min: -0.0000, max: 0.0000, mean: 0.0000, std: 0.0000
  Label: 2
  Task: mixedeventrelatedprobe
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Sample 1:
  Image shape: torch.Size([1, 64, 64, 64])
  Image min: -0.0000, max: 0.0000, mean: -0.0000, std: 0.0000
  Label: 1
  Task: probabilisticclassification
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Sample 2:
  Image shape: torch.Size([1, 64, 64, 64])
  Image min: -0.0000, max: 0.0000, mean: 0.0000, std: 0.0000
  Label: 2
  Task: mixedeventrelatedprobe
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Sample 3:
  Image shape: torch.Size([1, 64, 64, 64])
  Image min: -0.0000, max: 0.0000, mean: -0.0000, std: 0.0000
  Label: 0
  Task: deterministicclassification
Original shape: (64, 64, 25, 180)
Processed shap

In [None]:
batch_size = 4
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print("\nTesting DataLoader:")
for i, (batch_img, batch_label, batch_task) in enumerate(dataloader):
    print(f"Batch {i}:")
    print(f"  Batch image shape: {batch_img.shape}")
    if batch_img.numel() > 0:
        print(f"  Batch image min: {batch_img.min().item():.4f}, max: {batch_img.max().item():.4f}, mean: {batch_img.mean().item():.4f}, std: {batch_img.std().item():.4f}")
    else:
        print("  Batch image is empty")
    print(f"  Batch label shape: {batch_label.shape}")
    print(f"  Batch task: {batch_task}")
    if i == 2:
        break


Testing DataLoader:
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Batch 0:
  Batch image shape: torch.Size([4, 1, 64, 64, 64])
  Batch image min: -0.0000, max: 0.0000, mean: 0.0000, std: 0.0000
  Batch label shape: torch.Size([4])
  Batch task: ('mixedeventrelatedprobe', 'deterministicclassification', 'probabilisticclassification', 'deterministicclassification')
Original shape: (64, 64, 30, 234)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 234)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Batch 1:
  Batch image shape: torch.Size([4, 1, 

Define the 3D CNN model

In [None]:
class ImprovedFMRICNN(nn.Module):
    def __init__(self, num_classes):
        super(ImprovedFMRICNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2),

            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2),

            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )

        self.classifier = nn.Sequential(
            nn.Linear(128 * 8 * 8 * 8, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

Learning rate scheduler

In [None]:
def get_lr_scheduler(optimizer):
    return optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

Early stopping

In [None]:
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss

Training function

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, scheduler):
    model.to(device)
    best_val_acc = 0.0
    best_model = None

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for inputs, labels, _ in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            train_total += labels.size(0)
            train_correct += predicted.eq(labels).sum().item()

        train_loss = train_loss / len(train_loader.dataset)
        train_acc = train_correct / train_total

        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels, _ in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

        val_loss = val_loss / len(val_loader.dataset)
        val_acc = val_correct / val_total

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}')

        scheduler.step(val_loss)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model = model.state_dict()

    model.load_state_dict(best_model)
    return model

In [None]:
print("First few samples:")
for i in range(min(5, len(dataset))):
    print(f"Sample {i}: {dataset.samples[i]}")

if len(dataset) == 0:
    print("Error: The dataset is empty. Please check the following:")
    print(f"1. The root_dir '{root_dir}' exists and contains the 'ds002_R2.0.5' folder.")
    print("2. The directory structure matches what's expected in the _load_samples method.")
    print("3. There are .nii.gz files in the func subdirectories.")

First few samples:
Sample 0: ('/content/drive/MyDrive/learnedSpectrum/fmri_data/ds002_R2.0.5/sub-03/func/sub-03_task-mixedeventrelatedprobe_run-01_bold.nii.gz', 2, 'mixedeventrelatedprobe')
Sample 1: ('/content/drive/MyDrive/learnedSpectrum/fmri_data/ds002_R2.0.5/sub-03/func/sub-03_task-probabilisticclassification_run-01_bold.nii.gz', 1, 'probabilisticclassification')
Sample 2: ('/content/drive/MyDrive/learnedSpectrum/fmri_data/ds002_R2.0.5/sub-03/func/sub-03_task-mixedeventrelatedprobe_run-02_bold.nii.gz', 2, 'mixedeventrelatedprobe')
Sample 3: ('/content/drive/MyDrive/learnedSpectrum/fmri_data/ds002_R2.0.5/sub-03/func/sub-03_task-deterministicclassification_run-01_bold.nii.gz', 0, 'deterministicclassification')
Sample 4: ('/content/drive/MyDrive/learnedSpectrum/fmri_data/ds002_R2.0.5/sub-03/func/sub-03_task-deterministicclassification_run-02_bold.nii.gz', 0, 'deterministicclassification')


Balance the dataset

In [None]:
label_counts = [0, 0, 0]
for _, label, _ in dataset:
    if label != -1:
        label_counts[label] += 1

class_weights = [len(dataset) / (3 * count) for count in label_counts]
sample_weights = [class_weights[label] if label != -1 else 0 for _, label, _ in dataset]

Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: t

Split the data

In [None]:
indices = list(range(len(dataset)))

In [None]:
train_val_indices, test_indices = train_test_split(indices, test_size=0.2, random_state=42)
train_indices, val_indices = train_test_split(train_val_indices, test_size=0.2, random_state=42)

weighted samplers

In [None]:
train_sampler = WeightedRandomSampler([sample_weights[i] for i in train_indices], len(train_indices))
val_sampler = WeightedRandomSampler([sample_weights[i] for i in val_indices], len(val_indices))

data loaders

In [None]:
batch_size = 16
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)
test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_indices, num_workers=2)

Initialize model, loss, and optimizer

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ImprovedFMRICNN(num_classes=3)
criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(device))
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)



Train the model

In [None]:
num_epochs = 50
trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs, device, scheduler)

Epoch 1/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 2/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 3/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 4/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (6

Epoch 5/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Original shape: (64, 64, 30, 232)

Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 6/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 7/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 8/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 9/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 10/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 11/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 12/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 13/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 14/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 15/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 16/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 17/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 18/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 19/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])Original shape: (64, 64, 30, 232)

Processed shape: t

Epoch 20/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 21/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 22/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 23/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (6

Epoch 24/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 25/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 26/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 27/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 28/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 29/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (6

Epoch 30/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 31/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 32/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 33/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)Processed shape: torch.Size([1, 64, 64, 64])

Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 34/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 35/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 36/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 37/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 38/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 39/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 40/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 41/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 42/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)Original shape: (64, 64, 25, 237)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 43/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)Original shape: (64, 64, 30, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 44/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 45/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: t

Epoch 46/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 47/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)Processed shape: torch.Size([1, 64, 64, 64])

Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 48/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 49/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])Processed shape: torch.Size([1, 64, 64, 64])

Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (64, 64, 25, 180)Processed shape: torch.Size([1, 64, 64, 64])

Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Epoch 50/50:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)Original shape: (64, 64, 25, 180)

Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Processed shape: t

Evaluate the model

In [None]:
trained_model.eval()
all_preds = []
all_labels = []
all_tasks = []

with torch.no_grad():
    for inputs, labels, tasks in tqdm(test_loader, desc="Evaluating"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = trained_model(inputs)
        _, predicted = outputs.max(1)
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_tasks.extend(tasks)

Evaluating:   0%|          | 0/7 [00:00<?, ?it/s]

Original shape: (64, 64, 25, 237)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 237)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Original shape: (64, 64, 30, 232)
Processed shape: torch.Size([1, 64, 64, 64])
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 25, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 180)
Processed shape: torch.Size([1, 64, 64, 64])
Original shape: (64, 64, 30, 232)
Original shape: (6

Print classification report

In [None]:
print(classification_report(all_labels, all_preds, target_names=['Deterministic', 'Probabilistic', 'Mixed']))

               precision    recall  f1-score   support

Deterministic       0.33      1.00      0.50        34
Probabilistic       0.00      0.00      0.00        34
        Mixed       0.00      0.00      0.00        34

     accuracy                           0.33       102
    macro avg       0.11      0.33      0.17       102
 weighted avg       0.11      0.33      0.17       102



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
