In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet18

class FiveResNet18MLP5(nn.Module):
    """Two-stream regressor: ten headless ResNet-18 encoders feeding a 5-layer MLP.

    Each of the five "current" views and each of the five "goal" views has its
    own ResNet-18 (``weights=None``, i.e. random init) with the final ``fc``
    layer removed. The ten pooled embeddings (512 values each, hence the
    10 x 512 = 5120 input width of ``fc_layer1``) are concatenated, current
    views first and goal views second, and mapped to one scalar per sample.

    Submodule names ``current_resnet1..5``, ``goal_resnet1..5`` and
    ``fc_layer1..5`` are part of the saved ``state_dict`` keys.
    """

    def __init__(self):
        super().__init__()

        # Creation order (all current encoders, then all goal encoders, view
        # 1..5 each) fixes how random initialisation consumes the RNG.
        for stream in ('current', 'goal'):
            for view_number in range(1, 6):
                backbone = resnet18(weights=None)
                # Keep every child except the final classification layer.
                headless = nn.Sequential(*list(backbone.children())[:-1])
                setattr(self, f'{stream}_resnet{view_number}', headless)

        # MLP head: 5120 -> 1024 -> 1024 -> 1024 -> 1024 -> 1
        self.fc_layer1 = nn.Sequential(nn.Linear(5120, 1024), nn.ReLU())
        for layer_number in range(2, 5):
            setattr(
                self,
                f'fc_layer{layer_number}',
                nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()),
            )
        self.fc_layer5 = nn.Linear(1024, 1)

    def forward(self, current_images, goal_images):
        """Predict one value per sample.

        Args:
            current_images: tensor indexed as ``[:, view]`` for views 0..4;
                each slice goes through a ResNet-18, so presumably shaped
                ``(batch, 5, 3, H, W)`` as stacked by SPOTDataLoader.
            goal_images: same layout as ``current_images``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        embeddings = []
        for stream, images in (('current', current_images), ('goal', goal_images)):
            for view_index in range(5):
                encoder = getattr(self, f'{stream}_resnet{view_index + 1}')
                pooled = encoder(images[:, view_index, :, :])
                embeddings.append(torch.flatten(pooled, start_dim=1))

        # One concatenation in the order current1..5, goal1..5.
        hidden = torch.cat(embeddings, dim=1)
        for layer_number in range(1, 5):
            hidden = getattr(self, f'fc_layer{layer_number}')(hidden)
        return self.fc_layer5(hidden)

In [None]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from google.colab import drive

# Mount Google Drive so the dataset paths under /content/drive resolve.
drive.mount('/content/drive')

class SPOTDataLoader(Dataset):
    """Dataset of ``(current views, goal views, label)`` samples.

    Sample ``idx`` reads ``0.png`` ... ``4.png`` from the zero-padded folder
    ``<root_dir>/<idx:05d>/``. It also reads the same five file names from
    ``os.path.join(root_dir, goal_folder)``; an absolute ``goal_folder`` is used
    as-is by ``os.path.join``. The label is entry ``idx`` of the array stored in
    ``labels_file``.

    When CUDA is available every returned tensor is already on the GPU. The
    DataLoaders in this notebook use ``num_workers=0``. Creating CUDA tensors
    inside worker processes is generally not supported, so keep it that way.

    Args:
        root_dir: folder containing the per-sample sub-folders.
        goal_folder: folder holding the five goal images.
        labels_file: ``.npy`` file with one label per sample.
        transform: callable applied to each RGB PIL image. The images are
            stacked with ``torch.stack``, so it must return tensors.
    """

    # Number of camera views stored per sample / per goal.
    _NUM_VIEWS = 5

    def __init__(self, root_dir, goal_folder, labels_file, transform=None):
        self.root_dir = root_dir
        self.goal_folder = goal_folder
        self.transform = transform
        self.labels = np.load(labels_file)
        # Whether __getitem__ moves its outputs to the GPU.
        self.cuda = torch.cuda.is_available()

    def __len__(self):
        return self.labels.shape[0]

    def _load_views(self, folder_path):
        """Load ``0.png`` ... ``4.png`` from ``folder_path`` and stack them along dim 0."""
        views = []
        for view in range(self._NUM_VIEWS):
            image_path = os.path.join(folder_path, f"{view}.png")
            # The context manager closes the file deterministically; convert()
            # returns an independent in-memory image that outlives it.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            if self.transform:
                image = self.transform(image)
            views.append(image)
        return torch.stack(views, dim=0)

    def __getitem__(self, idx):
        folder_path = os.path.join(self.root_dir, format(idx, '05d'))
        input_images = self._load_views(folder_path)
        goal_images = self._load_views(os.path.join(self.root_dir, self.goal_folder))
        label_tensor = torch.tensor(self.labels[idx])

        if self.cuda:
            input_images = input_images.cuda()
            goal_images = goal_images.cuda()
            label_tensor = label_tensor.cuda()

        return input_images, goal_images, label_tensor

Mounted at /content/drive


In [None]:
import os
import numpy as np

# Root of the train/test split on the mounted Drive; split paths keep a trailing slash.
DATASET_INITIAL_PATH = '/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test'
TRAIN_PATH = f'{DATASET_INITIAL_PATH}/train/'
TEST_PATH = f'{DATASET_INITIAL_PATH}/test/'

# Echo the resolved split folders, test first, then train.
print(TEST_PATH, TRAIN_PATH, sep='\n')

def quaternion_to_radians(quaternion):
    """Return the yaw angle (rotation about z) in radians.

    The last four entries along the final axis are read as ``(qx, qy, qz, qw)``;
    any leading entries are ignored. A 1-D input gives a scalar. A 2-D
    ``(N, >=4)`` array gives an ``(N,)`` array with one yaw per row.
    """
    components = np.asarray(quaternion)[..., -4:]
    # Move the component axis first so the unpacking works for 1-D and 2-D input.
    qx, qy, qz, qw = np.moveaxis(components, -1, 0)
    return np.arctan2(2 * (qw * qz + qx * qy), 1 - 2 * (qy ** 2 + qz ** 2))

def convert_labels_to_radians(path):
    """Convert ``<path>/labels.npy`` quaternion rows into ``<path>/labels_radians.npy`` yaw values.

    Each row of ``labels.npy`` is converted with ``quaternion_to_radians``.
    If ``labels.npy`` does not exist, nothing is written and a notice is printed.
    """
    labels_path = os.path.join(path, 'labels.npy')
    radians_labels_path = os.path.join(path, 'labels_radians.npy')
    if not os.path.exists(labels_path):
        print(f"No labels found at {labels_path}; skipping conversion")
        return
    labels = np.load(labels_path)
    radians_labels = np.asarray(quaternion_to_radians(labels))
    np.save(radians_labels_path, radians_labels)
    # Report the file that was actually written.
    print(f"Converted quaternions to radians and saved to {radians_labels_path}")

# Derive labels_radians.npy for the train split, then the test split.
for _split_path in (TRAIN_PATH, TEST_PATH):
    convert_labels_to_radians(_split_path)

/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test/test/
/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test/train/
Converted quaternions to radians and saved to /content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test/train/labels.npy
Converted quaternions to radians and saved to /content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test/test/labels.npy


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os


WEIGHT_SAVING_STEP = 10
DPI = 120
FIGURE_SIZE_PIXEL = [2490, 1490]
FIGURE_SIZE = [fsp / DPI for fsp in FIGURE_SIZE_PIXEL]


def _finish_figure(figure_path, file_name):
    """Save the current figure as ``figure_path/file_name`` and close it, or show it if no path is given."""
    if figure_path is not None:
        # os.path.join works whether or not figure_path ends with a separator.
        plt.savefig(os.path.join(figure_path, file_name))
        plt.close()
    else:
        plt.show()


def plot_graph(training_losses, accuracies, figure_path=None, fold=0, start_plot=0, end_plot=0):
    """Plot per-epoch training loss and train/validation accuracy for one fold.

    Args:
        training_losses: sequence of ``[loss, average_loss]`` pairs, one per epoch.
        accuracies: sequence of ``[train_accuracy, valid_accuracy]`` pairs in percent.
        figure_path: directory to save ``Fold_<fold>_Training_loss.png`` and
            ``Fold_<fold>_Accuracy.png`` into; ``None`` shows the figures instead.
        fold: fold index used in titles and file names.
        start_plot: number of leading epochs left out of the plots.
        end_plot: number of epochs covered; entries ``[start_plot:end_plot]`` are plotted.

    Raises:
        ValueError: if ``start_plot > end_plot`` or either sequence has fewer
            than ``end_plot`` entries.

    The input sequences are never modified. Nothing is plotted when
    ``start_plot == end_plot``.
    """
    if start_plot == end_plot:
        return
    if start_plot > end_plot:
        raise ValueError(f'start_plot ({start_plot}) must not exceed end_plot ({end_plot})')
    available = min(len(training_losses), len(accuracies))
    if end_plot > available:
        raise ValueError(f'end_plot ({end_plot}) exceeds the number of recorded epochs ({available})')

    # 1-based epoch numbers of the plotted window.
    epochs = range(start_plot + 1, end_plot + 1)

    # ---- Training loss ----
    loss_window = [row[0] for row in training_losses[start_plot:end_plot]]
    average_window = [row[1] for row in training_losses[start_plot:end_plot]]

    plt.figure(figsize=FIGURE_SIZE, dpi=DPI)
    plt.scatter(epochs, loss_window, color='blue', label='Training Loss')
    plt.plot(epochs, average_window, color='cyan', linestyle='-', label='Average Training Loss')
    plt.title(f"Fold {fold} Training Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss (1000 radians)")
    plt.legend()

    # Annotate every checkpoint epoch and the final epoch.
    for epoch, loss in zip(epochs, loss_window):
        if epoch % WEIGHT_SAVING_STEP == 0:
            plt.annotate(str(round(loss, 6)), xy=(epoch, loss))
    plt.annotate(str(round(loss_window[-1], 6)), xy=(end_plot, loss_window[-1]))

    # Lowest loss within the plotted window only.
    plt.text(0, plt.gca().get_ylim()[1], f'Lowest Loss: {min(loss_window): .6f}')

    _finish_figure(figure_path, f'Fold_{fold}_Training_loss.png')

    # ---- Accuracy ----
    train_window = [row[0] for row in accuracies[start_plot:end_plot]]
    valid_window = [row[1] for row in accuracies[start_plot:end_plot]]

    plt.figure(figsize=FIGURE_SIZE, dpi=DPI)
    plt.plot(epochs, train_window, color='blue', linestyle='-', marker='o', label='Training Accuracy')
    plt.plot(epochs, valid_window, color='orange', linestyle='-', marker='o', label='Validation Accuracy')
    plt.title(f"Fold {fold} Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy (%)")
    plt.legend()

    for epoch, train_acc, valid_acc in zip(epochs, train_window, valid_window):
        if epoch % WEIGHT_SAVING_STEP == 0:
            plt.annotate(str(round(train_acc, 2)), xy=(epoch, train_acc))
            plt.annotate(str(round(valid_acc, 2)), xy=(epoch, valid_acc))
    plt.annotate(str(round(train_window[-1], 2)), xy=(end_plot, train_window[-1]))
    plt.annotate(str(round(valid_window[-1], 2)), xy=(end_plot, valid_window[-1]))

    _finish_figure(figure_path, f'Fold_{fold}_Accuracy.png')

In [None]:
from torch.utils.data import DataLoader
import torch, os
from torchvision import transforms
import numpy as np
from sklearn.model_selection import KFold

# K-fold training of FiveResNet18MLP5 on the train split.
# Uses FiveResNet18MLP5, SPOTDataLoader and plot_graph from earlier cells.

# Setup Destination
DATASET_NAME = 'mixed'
DATASET_INITIAL_PATH = '/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test'
WEIGHT_FOLDER_NAME = 'lr1e-5_with_scaling'

# Paths
TRAIN_PATH = os.path.join(DATASET_INITIAL_PATH, 'train/')
TEST_PATH = os.path.join(DATASET_INITIAL_PATH, 'test/')
GOAL_PATH = os.path.join(DATASET_INITIAL_PATH, 'goal/goal_images/')
LABEL_PATH = os.path.join(TRAIN_PATH, 'labels_radians.npy')

print(TRAIN_PATH)
print(TEST_PATH)
print(GOAL_PATH)
print(LABEL_PATH)

# Output directories. They have no trailing separator, so file names are
# always attached with os.path.join (never string concatenation).
WEIGHT_PATH = os.path.join(DATASET_INITIAL_PATH, 'train', f'weights/Train_FiveResNet18MLP5_{DATASET_NAME}', WEIGHT_FOLDER_NAME)
FIGURE_PATH = os.path.join(DATASET_INITIAL_PATH, 'train', f'Results/Train_FiveResNet18MLP5_{DATASET_NAME}', WEIGHT_FOLDER_NAME)

# Ensure directories exist
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(FIGURE_PATH, exist_ok=True)

# KFold Parameters
NUM_FOLD = 5
KFOLD_SEED = 42   # fixed so fold membership is identical across sessions (needed to resume a fold)
START_FOLD = 2    # folds with a lower index are skipped (e.g. already trained in an earlier session)
k_fold = KFold(NUM_FOLD, shuffle=True, random_state=KFOLD_SEED)
CONTINUE = [0] * NUM_FOLD   # per-fold checkpoint epoch to resume from; 0 trains from scratch

# Preprocess for images (ImageNet mean/std normalisation)
data_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# SPOTDataLoader itself moves each sample to the GPU when one is available.
train_dataset = SPOTDataLoader(
    root_dir=TRAIN_PATH,
    goal_folder=GOAL_PATH,
    labels_file=LABEL_PATH,
    transform=data_transforms,
)
print('Cuda' if DEVICE == 'cuda' else 'CPU')

# Hyper Parameters
loss_fn = torch.nn.MSELoss()
BATCH_SIZE = 2
LEARNING_RATE = 1e-5

# Training Parameters
WEIGHT_SAVING_STEP = 10
LOSS_SCALE = 1e3
MAX_EPOCHS = 1000   # hard stop in case the loss never reaches LOSS_TARGET

# Validation Parameter: a prediction counts as correct when |error| < TOLERANCE (radians)
TOLERANCE = 1e-4
# Stop once the mean scaled MSE drops below TOLERANCE**2 * LOSS_SCALE,
# i.e. an RMS error of roughly TOLERANCE radians.
LOSS_TARGET = (TOLERANCE ** 2) * LOSS_SCALE

# Saving Hyper Param (inside WEIGHT_PATH, not next to it)
hyper_params_path = os.path.join(WEIGHT_PATH, 'hyper_params')
hyper_params = {'NUM_FOLD': NUM_FOLD, 'BATCH_SIZE': BATCH_SIZE, 'LEARNING_RATE': LEARNING_RATE,
                'LOSS_SCALE': LOSS_SCALE, 'TOLERANCE': TOLERANCE,
                'KFOLD_SEED': KFOLD_SEED, 'MAX_EPOCHS': MAX_EPOCHS}
np.savez(hyper_params_path, **hyper_params)


@torch.no_grad()
def _tolerance_accuracy(model, dataloader, tolerance, device):
    """Return the percentage of samples whose absolute prediction error is below ``tolerance``."""
    num_correct, num_total = 0, 0
    for current_images, goal_images, labels in dataloader:
        output = model(current_images.to(device), goal_images.to(device)).flatten()
        # Compare in the model's dtype so the threshold test is done on the float32 output.
        errors = (output - labels.to(device=device, dtype=output.dtype)).abs()
        num_correct += int((errors < tolerance).sum().item())
        num_total += errors.numel()
    return (num_correct / num_total) * 100 if num_total else 0.0


for fold, (train_ids, valid_ids) in enumerate(k_fold.split(train_dataset)):
    if fold < START_FOLD:
        continue

    print(f'FOLD {fold}')
    fold_path = os.path.join(WEIGHT_PATH, f'fold_{fold}')
    os.makedirs(fold_path, exist_ok=True)

    # Setup Model
    model = FiveResNet18MLP5().to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Tracking Parameters
    epoch = CONTINUE[fold] + 1
    training_loss = 1e6
    training_total_loss = 0
    training_losses = []   # [training_loss, training_average_loss] per epoch
    tracking_losses_path = os.path.join(fold_path, 'training_losses.npy')
    accuracies = []        # [train_accuracy, valid_accuracy] per epoch
    accuracies_path = os.path.join(fold_path, 'accuracies.npy')

    if CONTINUE[fold] > 0:
        # NOTE(review): only model weights are restored; Adam moments and the
        # moving-average accumulator restart from scratch.
        checkpoint_path = os.path.join(fold_path, f'epoch_{CONTINUE[fold]}.pth')
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        print('Weight Loaded!')
        training_losses = list(np.load(tracking_losses_path))[:CONTINUE[fold]]
        accuracies = list(np.load(accuracies_path))[:CONTINUE[fold]]
        print(f'Fold {fold} Parameter Loaded!')

    train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
    valid_subsampler = torch.utils.data.SubsetRandomSampler(valid_ids)

    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_subsampler, num_workers=0)
    valid_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=valid_subsampler, num_workers=0)

    print(f"Number of batches in train_dataloader: {len(train_dataloader)}")

    # Sanity check: one batch loads and has the expected shapes.
    current_images, goal_images, labels = next(iter(train_dataloader))
    print(f"Batch shapes: {current_images.shape}, {goal_images.shape}, {labels.shape}")

    while training_loss > LOSS_TARGET and epoch <= MAX_EPOCHS:
        # Re-enter training mode every epoch. The accuracy pass below switches to
        # eval mode, which would otherwise keep the ResNet BatchNorm layers in
        # inference mode for every later epoch.
        model.train()
        running_loss = 0.0

        for current_images, goal_images, labels in train_dataloader:
            optimizer.zero_grad()
            output = model(current_images.to(DEVICE), goal_images.to(DEVICE))
            loss = loss_fn(output.flatten(), labels.to(DEVICE).float()) * LOSS_SCALE
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        training_loss = running_loss / len(train_dataloader)

        # Moving average, with n = epochs already recorded:
        #   avg_n = (n * avg_{n-1} + 5 * loss_n) / (n + 5)
        # The new loss gets weight 5 against the running history.
        training_total_loss += training_loss * 5
        training_average_loss = training_total_loss / (len(training_losses) + 5)
        training_total_loss = training_average_loss * (len(training_losses) + 1)

        # Save training loss
        training_losses.append([training_loss, training_average_loss])
        print(f'Epoch {epoch}, Loss: {training_loss:.6f}, Average Loss: {training_average_loss:.6f}', end='; ')
        np.save(tracking_losses_path, training_losses)

        if (epoch % WEIGHT_SAVING_STEP) == 0:
            torch.save(model.state_dict(), os.path.join(fold_path, f'epoch_{epoch}.pth'))
            print('Save Weights', end='; ')

        # Evaluate tolerance accuracy on the training and validation subsets.
        model.eval()
        train_accuracy = _tolerance_accuracy(model, train_dataloader, TOLERANCE, DEVICE)
        valid_accuracy = _tolerance_accuracy(model, valid_dataloader, TOLERANCE, DEVICE)
        accuracies.append([train_accuracy, valid_accuracy])
        print(f'Train Accuracy {train_accuracy:.2f}%, Valid Accuracy: {valid_accuracy:.2f}%')
        np.save(accuracies_path, accuracies)

        epoch += 1

    if training_loss > LOSS_TARGET:
        print(f'Stopped at MAX_EPOCHS={MAX_EPOCHS} before reaching the loss target')
    print(f'Finished Training fold {fold}')
    epoch -= 1

    # Save last weight
    torch.save(model.state_dict(), os.path.join(fold_path, f'epoch_{epoch}.pth'))
    print('Save Last Weights')

    # Plot Training Loss and Accuracies graphs. The trailing separator keeps the
    # figures inside FIGURE_PATH even if plot_graph concatenates file names.
    plot_graph(training_losses, accuracies, os.path.join(FIGURE_PATH, ''), fold, end_plot=epoch)

'\nfrom torch.utils.data import DataLoader\nimport torch, os\nfrom torchvision import transforms\nimport numpy as np\nfrom sklearn.model_selection import KFold\n\n# Setup Destination\nDATASET_NAME = \'mixed\'\nDATASET_INITIAL_PATH = \'/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test\'\nWEIGHT_FOLDER_NAME = \'lr1e-5_with_scaling\'\n\n# Paths\nTRAIN_PATH = os.path.join(DATASET_INITIAL_PATH, \'train/\')\nTEST_PATH = os.path.join(DATASET_INITIAL_PATH, \'test/\')\nGOAL_PATH = os.path.join(DATASET_INITIAL_PATH, \'goal/goal_images/\')\nLABEL_PATH = os.path.join(TRAIN_PATH, \'labels_radians.npy\')\n\nprint(TRAIN_PATH)\nprint(TEST_PATH)\nprint(GOAL_PATH)\nprint(LABEL_PATH)\n\n# Output Paths\nWEIGHT_PATH = os.path.join(DATASET_INITIAL_PATH,\'train\', f\'weights/Train_FiveResNet18MLP5_{DATASET_NAME}\', WEIGHT_FOLDER_NAME)\nFIGURE_PATH = os.path.join(DATASET_INITIAL_PATH, \'train\',f\'Results/Train_FiveResNet18MLP5_{DATASET_NAME}\', WEIGHT_FOLDER_NAME)\n\n# Ensure directories exist\

In [None]:
from torch.utils.data import DataLoader
import torch, os
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt

# Parameters: which checkpoint to evaluate, and on which dataset split
WEIGHT_DATASET_NAME = 'mixed'
TEST_DATASET_NAME = 'mixed'
TEST_DATASET_TYPE = 'test'
FOLD = 0
WEIGHT_NAME = 'epoch_297.pth'

# Constants: input locations on the mounted Drive
DATASET_PATH = '/content/drive/MyDrive/Spot_IL/CLIPSeg_Mixed_Dataset_Train_Test'
TEST_PATH = os.path.join(DATASET_PATH, TEST_DATASET_TYPE)
GOAL_PATH = os.path.join(DATASET_PATH, 'goal', 'goal_images')
LABEL_PATH = os.path.join(TEST_PATH, 'labels_radians.npy')

# Output Paths: <DATASET_PATH>/test/{weights,Results}/FiveResNet18MLP5_<name>/lr1e-5_with_scaling/fold_<FOLD>
# NOTE(review): the training cell writes checkpoints under
# train/weights/Train_FiveResNet18MLP5_<name>/..., not here. Confirm the
# checkpoints are copied into WEIGHT_PATH before evaluating.
WEIGHT_PATH, FIGURES_PATH = (
    os.path.join(DATASET_PATH, 'test', kind, f'FiveResNet18MLP5_{WEIGHT_DATASET_NAME}', 'lr1e-5_with_scaling', f'fold_{FOLD}')
    for kind in ('weights', 'Results')
)

# Ensure directories exist
os.makedirs(WEIGHT_PATH, exist_ok=True)
os.makedirs(FIGURES_PATH, exist_ok=True)

# Figure geometry: matplotlib figsize is in inches, so convert pixels via DPI.
DPI = 120
FIGURE_SIZE_PIXEL = [2490, 1490]
FIGURE_SIZE = [pixel_count / DPI for pixel_count in FIGURE_SIZE_PIXEL]

def test_model(test_dataset, model, weight_name, device='cuda', draw=False, show=False):

    model.to(device)
    # model.load_state_dict(torch.load(os.path.join(WEIGHT_PATH, weight_name)))
    model.eval()

    test_dataloader = DataLoader(test_dataset, batch_size=1)
    results = np.empty([0, 3])

    with torch.no_grad():
        idx = 0
        for current_images, goal_images, label in test_dataloader:
            output = model(current_images, goal_images)
            output_degree = (output.item() / np.pi) * 180
            label_degree = (label.item() / np.pi) * 180
            loss = abs(label - output)

            iteration_result = np.array([output_degree, label_degree, loss.item()])
            results = np.vstack([results, iteration_result])

            print(idx, iteration_result)
            idx += 1

    # Plot
    if draw is True:
        plt.figure(figsize=FIGURE_SIZE, dpi=DPI)
        plt.plot(range(len(test_dataloader)), results[:, 0], color='green', linestyle='-', label='Predicted Rotation Angle')
        plt.plot(range(len(test_dataloader)), results[:, 1], color='cyan', linestyle='-', label='GT Rotation Angle')
        plt.plot(range(len(test_dataloader)), results[:, 2], color='blue', linestyle='-', label='Difference')
        plt.title(f'Test for {weight_name}')
        plt.xlabel("Datapoint")
        plt.ylabel("Degree")
        plt.legend()

        if show is True:
            plt.show()
        else:
            file_name = FIGURES_PATH + f'{weight_name.split(".pth")[0]}_test_with_{TEST_DATASET_NAME}_{TEST_DATASET_TYPE}'
            plt.savefig(file_name + '.png')
            plt.close()

            np.savetxt(file_name + '.csv', results, delimiter=',')

if __name__ == '__main__':

    # Network instance handed to test_model below.
    model = FiveResNet18MLP5()

    # Same preprocessing as training: resize, tensorise, ImageNet normalisation.
    data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # The dataset is built identically on both devices; only the device label differs.
    test_dataset = SPOTDataLoader(
        root_dir=TEST_PATH,
        goal_folder=GOAL_PATH,
        labels_file=LABEL_PATH,
        transform=data_transforms,
    )
    if torch.cuda.is_available():
        DEVICE = 'cuda'
        print('Cuda')
    else:
        DEVICE = 'cpu'
        print('CPU')

    weight_name = WEIGHT_NAME
    test_model(test_dataset, model, weight_name, device=DEVICE, draw=True, show=False)

Cuda
0 [0.33789298 0.97402825 0.01110265]
1 [ 0.52705747 -0.97402825  0.02619889]
2 [ 0.32813896 -0.97402825  0.02272711]
3 [ 0.54641215 -0.97402825  0.02653669]
4 [0.61098438 0.97402825 0.00633631]
5 [ 0.37340716 -0.97402825  0.02351718]
6 [ 0.10473634 -0.97402825  0.01882799]
7 [ 0.48596785 -0.97402825  0.02548174]
8 [0.44158646 0.97402825 0.00929286]
9 [0.36719823 0.97402825 0.01059118]
10 [0.56791438 0.97402825 0.00708802]
11 [0.44859407 0.97402825 0.00917056]
12 [0.50297646 0.97402825 0.0082214 ]
13 [0.1888681  0.97402825 0.01370363]
14 [ 0.29718142 -0.97402825  0.02218679]
15 [ 0.27508181 -0.97402825  0.02180108]
16 [ 0.23554014 -0.97402825  0.02111095]
17 [ 0.53648804 -0.97402825  0.02636348]
18 [0.45839577 0.97402825 0.00899948]
19 [ 0.06939491 -0.97402825  0.01821117]
20 [0.17662015 0.97402825 0.0139174 ]
21 [0.54186276 0.97402825 0.00754271]
22 [0.49429375 0.97402825 0.00837295]
23 [ 0.54661897 -0.97402825  0.0265403 ]
24 [ 0.3117992  -0.97402825  0.02244192]
25 [0.52706329 0