In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
import torchvision.transforms.functional as TF
import torchvision as tv

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

from torch.cuda.amp import autocast, GradScaler
import plotly.graph_objects as go
import seaborn as sns
import plotly.express as px
import pandas as pd

In [27]:
X_array =  np.load('X_array.npy')
Y_array =  np.load('Y_array.npy')


In [28]:
X_train, X_test, y_train, y_test = train_test_split(X_array, Y_array, test_size=0.2, random_state=42)

In [29]:
X_train = torch.tensor(X_train)
y_train = torch.tensor(y_train)
X_test = torch.tensor(X_test)
y_test = torch.tensor(y_test)

In [30]:
print(len(X_train))
print(len(y_train))
print(len(X_test))
print(len(y_test))

359
359
90
90


In [31]:
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        # Добавляем каналы (если они отсутствуют)
        if x.dim() == 3:  # Проверяем, что входной тензор имеет размерность [N, H, W]
            x = x.unsqueeze(1)  # Добавляем размерность каналов
            #x = torch.cat([x, x, x], dim=1)  # Дублируем каналы, чтобы получить RGB изображение
        # Интерполируем изображения до размера 224x224
        self.x = F.interpolate(x, size=(224, 224), mode='bilinear', align_corners=False)
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Создаем Dataset и DataLoader
dataset_train = CustomDataset(X_train, y_train)
dataset_test = CustomDataset(X_test, y_test)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=64, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=64, shuffle=False)

In [32]:
for x, y in dataloader_train:
    print(x.size())

torch.Size([64, 1, 224, 224])
torch.Size([64, 1, 224, 224])
torch.Size([64, 1, 224, 224])
torch.Size([64, 1, 224, 224])
torch.Size([64, 1, 224, 224])
torch.Size([39, 1, 224, 224])


In [33]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
model.eval()

Using cache found in C:\Users\lera-/.cache\torch\hub\pytorch_vision_v0.10.0

The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.



AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [34]:
# Заменяем последний полносвязный слой на новый слой, который будет соответствовать вашему набору данных
num_classes = 10  # количество классов в вашем наборе данных
model.features[0] = nn.Conv2d(1, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
model.classifier[6] = nn.Linear(4096, num_classes)

In [35]:
for param in model.parameters():
    param.data = param.data.double()


In [36]:
# Определение функции потерь и оптимизатора
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))


In [37]:
def accuracy(pred, label, threshold=0.5):
    pred = torch.sigmoid(pred)
    pred_labels = (pred > threshold).float()  # Threshold the predictions
    correct = (pred_labels == label).sum().item()  # Compare predictions with labels
    total = label.size(0) * label.size(1)  # Total number of labels
    return correct / total

In [38]:
num_epochs = 30  # количество эпох для обучения

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    acc_val = 0.0
    
    for matrix, labels in tqdm(dataloader_train, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        matrix, labels = matrix.to(device), labels.to(device)

        optimizer.zero_grad()
        
        outputs = model(matrix)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        loss_item = loss.item()
        running_loss += loss_item
        
        
        acc_current = accuracy(outputs.cpu().float(), labels.cpu().float())
        acc_val += acc_current

      
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader_train):.3f}, Accuracy: {acc_val/len(dataloader_train):.3f}") 

                                                          

Epoch [1/100], Loss: 0.612, Accuracy: 0.740


                                                          

Epoch [2/100], Loss: 0.477, Accuracy: 0.787


                                                          

Epoch [3/100], Loss: 0.443, Accuracy: 0.808


                                                          

Epoch [4/100], Loss: 0.406, Accuracy: 0.829


                                                          

Epoch [5/100], Loss: 0.383, Accuracy: 0.834


                                                          

Epoch [6/100], Loss: 0.369, Accuracy: 0.839


                                                          

Epoch [7/100], Loss: 0.340, Accuracy: 0.849


                                                          

Epoch [8/100], Loss: 0.320, Accuracy: 0.858


                                                          

Epoch [9/100], Loss: 0.300, Accuracy: 0.870


                                                           

Epoch [10/100], Loss: 0.283, Accuracy: 0.877


                                                           

Epoch [11/100], Loss: 0.258, Accuracy: 0.892


                                                           

Epoch [12/100], Loss: 0.247, Accuracy: 0.889


                                                           

Epoch [13/100], Loss: 0.244, Accuracy: 0.901


                                                           

Epoch [14/100], Loss: 0.222, Accuracy: 0.905


                                                           

Epoch [15/100], Loss: 0.227, Accuracy: 0.907


                                                           

Epoch [16/100], Loss: 0.210, Accuracy: 0.915


                                                           

Epoch [17/100], Loss: 0.207, Accuracy: 0.910


                                                           

Epoch [18/100], Loss: 0.204, Accuracy: 0.911


                                                           

Epoch [19/100], Loss: 0.191, Accuracy: 0.918


                                                           

Epoch [20/100], Loss: 0.181, Accuracy: 0.924


                                                           

Epoch [21/100], Loss: 0.166, Accuracy: 0.931


                                                           

Epoch [22/100], Loss: 0.159, Accuracy: 0.930


                                                           

Epoch [23/100], Loss: 0.152, Accuracy: 0.934


                                                           

Epoch [24/100], Loss: 0.143, Accuracy: 0.941


                                                           

Epoch [25/100], Loss: 0.141, Accuracy: 0.937


                                                           

Epoch [26/100], Loss: 0.144, Accuracy: 0.934


                                                           

Epoch [27/100], Loss: 0.139, Accuracy: 0.942


                                                           

Epoch [28/100], Loss: 0.139, Accuracy: 0.938


                                                           

Epoch [29/100], Loss: 0.135, Accuracy: 0.943


                                                           

Epoch [30/100], Loss: 0.144, Accuracy: 0.939


                                                           

Epoch [31/100], Loss: 0.140, Accuracy: 0.938


                                                           

Epoch [32/100], Loss: 0.125, Accuracy: 0.949


                                                           

Epoch [33/100], Loss: 0.130, Accuracy: 0.942


                                                           

Epoch [34/100], Loss: 0.126, Accuracy: 0.950


                                                           

Epoch [35/100], Loss: 0.117, Accuracy: 0.950


                                                           

Epoch [36/100], Loss: 0.113, Accuracy: 0.950


                                                           

Epoch [37/100], Loss: 0.115, Accuracy: 0.952


                                                           

Epoch [38/100], Loss: 0.112, Accuracy: 0.950


                                                           

Epoch [39/100], Loss: 0.109, Accuracy: 0.952


                                                           

Epoch [40/100], Loss: 0.105, Accuracy: 0.955


                                                           

Epoch [41/100], Loss: 0.112, Accuracy: 0.949


                                                           

Epoch [42/100], Loss: 0.110, Accuracy: 0.955


                                                           

Epoch [43/100], Loss: 0.114, Accuracy: 0.952


                                                           

Epoch [44/100], Loss: 0.128, Accuracy: 0.946


                                                           

Epoch [45/100], Loss: 0.116, Accuracy: 0.946


                                                           

Epoch [46/100], Loss: 0.113, Accuracy: 0.950


                                                           

Epoch [47/100], Loss: 0.099, Accuracy: 0.958


                                                           

Epoch [48/100], Loss: 0.108, Accuracy: 0.953


                                                           

Epoch [49/100], Loss: 0.107, Accuracy: 0.953


                                                           

Epoch [50/100], Loss: 0.091, Accuracy: 0.964


                                                           

Epoch [51/100], Loss: 0.089, Accuracy: 0.961


                                                           

Epoch [52/100], Loss: 0.087, Accuracy: 0.964


                                                           

Epoch [53/100], Loss: 0.088, Accuracy: 0.963


                                                           

Epoch [54/100], Loss: 0.085, Accuracy: 0.961


                                                           

Epoch [55/100], Loss: 0.091, Accuracy: 0.963


                                                           

Epoch [56/100], Loss: 0.087, Accuracy: 0.963


                                                           

Epoch [57/100], Loss: 0.085, Accuracy: 0.963


                                                           

Epoch [58/100], Loss: 0.082, Accuracy: 0.963


                                                           

Epoch [59/100], Loss: 0.074, Accuracy: 0.968


                                                           

Epoch [60/100], Loss: 0.074, Accuracy: 0.966


                                                           

Epoch [61/100], Loss: 0.070, Accuracy: 0.969


                                                           

Epoch [62/100], Loss: 0.085, Accuracy: 0.963


                                                           

Epoch [63/100], Loss: 0.078, Accuracy: 0.965


                                                           

Epoch [64/100], Loss: 0.075, Accuracy: 0.967


                                                           

Epoch [65/100], Loss: 0.070, Accuracy: 0.973


                                                           

Epoch [66/100], Loss: 0.079, Accuracy: 0.967


                                                           

Epoch [67/100], Loss: 0.087, Accuracy: 0.964


                                                           

Epoch [68/100], Loss: 0.084, Accuracy: 0.964


                                                           

Epoch [69/100], Loss: 0.080, Accuracy: 0.964


                                                           

Epoch [70/100], Loss: 0.069, Accuracy: 0.970


                                                           

Epoch [71/100], Loss: 0.066, Accuracy: 0.969


                                                           

Epoch [72/100], Loss: 0.071, Accuracy: 0.968


                                                           

Epoch [73/100], Loss: 0.081, Accuracy: 0.966


                                                           

Epoch [74/100], Loss: 0.102, Accuracy: 0.963


                                                           

Epoch [75/100], Loss: 0.115, Accuracy: 0.956


                                                           

Epoch [76/100], Loss: 0.097, Accuracy: 0.963


                                                           

Epoch [77/100], Loss: 0.101, Accuracy: 0.958


                                                           

Epoch [78/100], Loss: 0.087, Accuracy: 0.963


                                                           

Epoch [79/100], Loss: 0.088, Accuracy: 0.965


                                                           

Epoch [80/100], Loss: 0.091, Accuracy: 0.962


                                                           

Epoch [81/100], Loss: 0.083, Accuracy: 0.964


                                                           

Epoch [82/100], Loss: 0.075, Accuracy: 0.968


                                                           

Epoch [83/100], Loss: 0.069, Accuracy: 0.969


                                                           

Epoch [84/100], Loss: 0.072, Accuracy: 0.968


                                                           

Epoch [85/100], Loss: 0.074, Accuracy: 0.965


                                                           

Epoch [86/100], Loss: 0.068, Accuracy: 0.970


                                                           

Epoch [87/100], Loss: 0.061, Accuracy: 0.973


                                                           

Epoch [88/100], Loss: 0.058, Accuracy: 0.973


                                                           

Epoch [89/100], Loss: 0.059, Accuracy: 0.974


                                                           

Epoch [90/100], Loss: 0.057, Accuracy: 0.972


                                                           

Epoch [91/100], Loss: 0.056, Accuracy: 0.978


                                                           

Epoch [92/100], Loss: 0.052, Accuracy: 0.977


                                                           

Epoch [93/100], Loss: 0.055, Accuracy: 0.977


                                                           

Epoch [94/100], Loss: 0.053, Accuracy: 0.975


                                                           

Epoch [95/100], Loss: 0.047, Accuracy: 0.981


                                                           

Epoch [96/100], Loss: 0.049, Accuracy: 0.978


                                                           

Epoch [97/100], Loss: 0.051, Accuracy: 0.976


                                                           

Epoch [98/100], Loss: 0.053, Accuracy: 0.975


                                                           

Epoch [99/100], Loss: 0.050, Accuracy: 0.978


                                                            

Epoch [100/100], Loss: 0.053, Accuracy: 0.975




In [39]:
loss_val = 0.0
acc_val = 0.0
for sample in tqdm(dataloader_test):
    matrix, label = sample[0].to(device), sample[1].to(device)

    # Добавим измерение каналов к данным
    #matrix = matrix.unsqueeze(1)
  

    pred = model(matrix)
    loss = criterion(pred, label)

    loss_item = loss.item()
    loss_val += loss_item


    acc_current = accuracy(pred.cpu().float(), label.cpu().float())
    acc_val += acc_current


print(f'Loss: {loss_val/len(dataloader_test):.5f}, Accuracy: {acc_val/len(dataloader_test):.3f}')

100%|██████████| 2/2 [00:01<00:00,  1.33it/s]

Loss: 1.78065, Accuracy: 0.748





In [40]:
def aroma_map_comparison(matrix, labels):
    matrix = matrix.unsqueeze(1)
    pred = model(matrix).to('cpu')
    labels = labels.to('cpu')

    df = pd.DataFrame(dict(
        r=torch.sigmoid(pred).detach().numpy()[0],
        theta=['Herbs and spices', 'Tobacco/Smoke', 'Wood', 'Berries', 'Citrus',
       'Fruits ', 'Nuts', 'Coffee', 'Chocolate/Cacao', 'Flowers']))
    
    df['Label'] = 'Predict'
    
    df1 = pd.DataFrame(dict(
        r=labels.detach().numpy(),
        theta=['Herbs and spices', 'Tobacco/Smoke', 'Wood', 'Berries', 'Citrus',
       'Fruits ', 'Nuts', 'Coffee', 'Chocolate/Cacao', 'Flowers']))
    
    df1['Label'] = 'Experiment'
    
    fig = px.line_polar(pd.concat([df, df1]), color="Label",  r='r', theta='theta', line_close=True, line_shape='linear', color_discrete_sequence=['#008080', '#FFC0CB'],
                    template="plotly_dark")
    

    
    fig.update_traces(fill='toself')
    fig.update_layout(polar=dict(radialaxis=dict(visible=True, range=[0, 1])))
    fig.show()

In [41]:
for sample in dataset_test:
    matrix, label = sample[0].to(device), sample[1].to(device)
    aroma_map_comparison(matrix, label)

: 