# Wine aroma prediction using ResNet 
Here we use ResNet to predict wine aroma from the previously obtained wine composition matrices

In [None]:
#Import the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import torchvision as tv
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from torch.cuda.amp import autocast, GradScaler
import plotly.graph_objects as go
import seaborn as sns
import plotly.express as px
import pandas as pd
from torchvision.transforms import functional as TF

## 1. Working with data
At this stage, we load the previously prepared data, and process the data for further neural networking 

In [14]:
# Loading data
X_array =  np.load('X_array.npy')
Y_array =  np.load('Y_array.npy')

In [15]:
# Separation of data into training, validation and test data in the ratio of 70:20:10
X_train, X_test, y_train, y_test = train_test_split(X_array, Y_array, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

In [16]:
# Conversion to a tensor
X_train = torch.tensor(X_train)
y_train = torch.tensor(y_train)
X_val = torch.tensor(X_val)
y_val = torch.tensor(y_val)
X_test = torch.tensor(X_test)
y_test = torch.tensor(y_test)

In [17]:
print(len(X_train))
print(len(y_train))
print(len(X_val))
print(len(y_val))
print(len(X_test))
print(len(y_test))

323
323
81
81
45
45


Since ResNet has been trained on 224x224 images, we need to bring our images to the same size by padding (adding white pixels in equal proportions to the bottom, top and sides of the image), a method chosen to preserve chemical information without distorting it through interpolation

In [20]:
# Create a custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        # Add channels if there are none (i.e. if the images have dimension [N, H, W]).
        if x.dim() == 3:  
            x = x.unsqueeze(1)  # We add channel dimensionality [N, 1, H, W]

        self.x = self.pad_to_224(x)  # Adding padding
        self.y = y

    def pad_to_224(self, x):
        padded_images = []
        for img in x:
            # We check that the input images are of size [1, 44, 100]
            _, h, w = img.shape
            assert h == 44 and w == 100, "Incorrect size of the input image"

            # Applying padding
            img_padded = TF.pad(img, (62, 90), fill=255)  # White background

            padded_images.append(img_padded)

        return torch.stack(padded_images)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Create Dataset and DataLoader
dataset_train = CustomDataset(X_train, y_train)
dataset_val = CustomDataset(X_val, y_val)
dataset_test = CustomDataset(X_test, y_test)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=32, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False)


In [21]:
for x, y in dataloader_train:
    print(x.size())

torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([3, 1, 224, 224])


## 3. Loading of the pre-trained model and its additional training
At this stage we will load the pre-trained model, change the necessary parameters, and retrain it on our data

In [22]:
# Loading a pre-trained model
model_resnet = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model_resnet.eval()

Using cache found in C:\Users\V/.cache\torch\hub\pytorch_vision_v0.10.0


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
# Replace the last full-link layer with a new layer that will match your dataset
num_classes = 10  
model_resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model_resnet.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)
model_resnet

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [24]:
# Definition of loss-function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model_resnet.parameters(), lr=0.0001, betas=(0.9, 0.999))


In [25]:
# Accuracy function
def accuracy(pred, label, threshold=0.5):
    pred = torch.sigmoid(pred)
    pred_labels = (pred > threshold).float()  # Threshold the predictions
    correct = (pred_labels == label).sum().item()  # Compare predictions with labels
    total = label.size(0) * label.size(1)  # Total number of labels
    return correct / total

In [None]:
# Transferring the neural network to cuda
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_resnet.to(device)

In [None]:
# Model retraining
num_epochs = 100 

for epoch in range(num_epochs):
    running_loss = 0.0
    acc_val = 0.0
    
    for matrix, labels in tqdm(dataloader_train, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        # Convert the data to the required format (float32) and move it to the device
        matrix, labels = matrix.float().to(device), labels.to(device)

        optimizer.zero_grad()
        
        outputs = model_resnet(matrix)

        # Calculate the error
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        
        # Accuracy 
        acc_current = accuracy(outputs.cpu().float(), labels.cpu().float())
        acc_val += acc_current

    # Output the average loss and accuracy over the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader_train):.3f}, Accuracy: {acc_val/len(dataloader_train):.3f}")

                                                            

Epoch [1/100], Loss: 1.342, Accuracy: 0.697


                                                            

Epoch [2/100], Loss: 0.495, Accuracy: 0.770


                                                            

Epoch [3/100], Loss: 0.481, Accuracy: 0.774


                                                            

Epoch [4/100], Loss: 0.496, Accuracy: 0.771


                                                            

Epoch [5/100], Loss: 0.470, Accuracy: 0.781


                                                            

Epoch [6/100], Loss: 0.473, Accuracy: 0.776


                                                            

Epoch [7/100], Loss: 0.480, Accuracy: 0.772


                                                            

Epoch [8/100], Loss: 0.483, Accuracy: 0.775


                                                            

Epoch [9/100], Loss: 0.480, Accuracy: 0.779


                                                             

Epoch [10/100], Loss: 0.474, Accuracy: 0.772


                                                             

Epoch [11/100], Loss: 0.464, Accuracy: 0.783


                                                             

Epoch [12/100], Loss: 0.471, Accuracy: 0.772


                                                             

Epoch [13/100], Loss: 0.470, Accuracy: 0.776


                                                             

Epoch [14/100], Loss: 0.478, Accuracy: 0.780


                                                             

Epoch [15/100], Loss: 0.497, Accuracy: 0.768


                                                             

Epoch [16/100], Loss: 0.484, Accuracy: 0.775


                                                             

Epoch [17/100], Loss: 0.466, Accuracy: 0.775


                                                             

Epoch [18/100], Loss: 0.466, Accuracy: 0.783


                                                             

Epoch [19/100], Loss: 0.469, Accuracy: 0.778


                                                             

Epoch [20/100], Loss: 0.477, Accuracy: 0.777


                                                             

Epoch [21/100], Loss: 0.472, Accuracy: 0.780


                                                             

Epoch [22/100], Loss: 0.474, Accuracy: 0.776


                                                             

Epoch [23/100], Loss: 0.462, Accuracy: 0.788


                                                             

Epoch [24/100], Loss: 0.469, Accuracy: 0.774


                                                             

Epoch [25/100], Loss: 0.473, Accuracy: 0.774


                                                             

Epoch [26/100], Loss: 0.472, Accuracy: 0.776


                                                             

Epoch [27/100], Loss: 0.483, Accuracy: 0.774


                                                             

Epoch [28/100], Loss: 0.477, Accuracy: 0.785


                                                             

Epoch [29/100], Loss: 0.469, Accuracy: 0.777


                                                             

Epoch [30/100], Loss: 0.464, Accuracy: 0.773


                                                             

Epoch [31/100], Loss: 0.465, Accuracy: 0.787


                                                             

Epoch [32/100], Loss: 0.471, Accuracy: 0.783


                                                             

Epoch [33/100], Loss: 0.468, Accuracy: 0.773


                                                             

Epoch [34/100], Loss: 0.473, Accuracy: 0.767


                                                             

Epoch [35/100], Loss: 0.477, Accuracy: 0.777


                                                             

Epoch [36/100], Loss: 0.484, Accuracy: 0.772


                                                             

Epoch [37/100], Loss: 0.466, Accuracy: 0.781


                                                             

Epoch [38/100], Loss: 0.482, Accuracy: 0.774


                                                             

Epoch [39/100], Loss: 0.470, Accuracy: 0.778


                                                             

Epoch [40/100], Loss: 0.473, Accuracy: 0.775


                                                             

Epoch [41/100], Loss: 0.470, Accuracy: 0.776


                                                             

Epoch [42/100], Loss: 0.477, Accuracy: 0.780


                                                             

Epoch [43/100], Loss: 0.461, Accuracy: 0.785


                                                             

Epoch [44/100], Loss: 0.468, Accuracy: 0.777


                                                             

Epoch [45/100], Loss: 0.478, Accuracy: 0.772


                                                             

Epoch [46/100], Loss: 0.473, Accuracy: 0.777


                                                             

Epoch [47/100], Loss: 0.478, Accuracy: 0.771


                                                             

Epoch [48/100], Loss: 0.475, Accuracy: 0.770


                                                             

Epoch [49/100], Loss: 0.471, Accuracy: 0.772


                                                             

Epoch [50/100], Loss: 0.473, Accuracy: 0.778


                                                             

Epoch [51/100], Loss: 0.479, Accuracy: 0.767


                                                             

Epoch [52/100], Loss: 0.463, Accuracy: 0.785


                                                             

Epoch [53/100], Loss: 0.472, Accuracy: 0.777


                                                             

Epoch [54/100], Loss: 0.490, Accuracy: 0.772


                                                             

Epoch [55/100], Loss: 0.474, Accuracy: 0.777


                                                             

Epoch [56/100], Loss: 0.464, Accuracy: 0.780


                                                             

Epoch [57/100], Loss: 0.477, Accuracy: 0.770


                                                             

Epoch [58/100], Loss: 0.469, Accuracy: 0.772


                                                             

Epoch [59/100], Loss: 0.479, Accuracy: 0.772


                                                             

Epoch [60/100], Loss: 0.485, Accuracy: 0.770


                                                             

Epoch [61/100], Loss: 0.474, Accuracy: 0.783


                                                             

Epoch [62/100], Loss: 0.464, Accuracy: 0.785


                                                             

Epoch [63/100], Loss: 0.470, Accuracy: 0.780


                                                             

Epoch [64/100], Loss: 0.470, Accuracy: 0.780


                                                             

Epoch [65/100], Loss: 0.474, Accuracy: 0.772


                                                             

Epoch [66/100], Loss: 0.464, Accuracy: 0.776


                                                             

Epoch [67/100], Loss: 0.474, Accuracy: 0.773


                                                             

Epoch [68/100], Loss: 0.481, Accuracy: 0.772


                                                             

Epoch [69/100], Loss: 0.469, Accuracy: 0.780


                                                             

Epoch [70/100], Loss: 0.469, Accuracy: 0.777


                                                             

Epoch [71/100], Loss: 0.464, Accuracy: 0.776


                                                             

Epoch [72/100], Loss: 0.479, Accuracy: 0.774


                                                             

Epoch [73/100], Loss: 0.474, Accuracy: 0.780


                                                             

Epoch [74/100], Loss: 0.485, Accuracy: 0.772


                                                             

Epoch [75/100], Loss: 0.463, Accuracy: 0.785


                                                             

Epoch [76/100], Loss: 0.482, Accuracy: 0.772


                                                             

Epoch [77/100], Loss: 0.475, Accuracy: 0.771


                                                             

Epoch [78/100], Loss: 0.467, Accuracy: 0.780


                                                             

Epoch [79/100], Loss: 0.459, Accuracy: 0.788


                                                             

Epoch [80/100], Loss: 0.467, Accuracy: 0.780


                                                             

Epoch [81/100], Loss: 0.463, Accuracy: 0.778


                                                             

Epoch [82/100], Loss: 0.470, Accuracy: 0.774


                                                             

Epoch [83/100], Loss: 0.470, Accuracy: 0.775


                                                             

Epoch [84/100], Loss: 0.479, Accuracy: 0.768


                                                             

Epoch [85/100], Loss: 0.461, Accuracy: 0.783


                                                             

Epoch [86/100], Loss: 0.472, Accuracy: 0.777


                                                             

Epoch [87/100], Loss: 0.474, Accuracy: 0.774


                                                             

Epoch [88/100], Loss: 0.471, Accuracy: 0.769


                                                             

Epoch [89/100], Loss: 0.494, Accuracy: 0.758


                                                             

Epoch [90/100], Loss: 0.478, Accuracy: 0.780


                                                             

Epoch [91/100], Loss: 0.470, Accuracy: 0.783


                                                             

Epoch [92/100], Loss: 0.457, Accuracy: 0.794


                                                             

Epoch [93/100], Loss: 0.468, Accuracy: 0.780


                                                             

Epoch [94/100], Loss: 0.467, Accuracy: 0.783


                                                             

Epoch [95/100], Loss: 0.459, Accuracy: 0.780


                                                             

Epoch [96/100], Loss: 0.476, Accuracy: 0.769


                                                             

Epoch [97/100], Loss: 0.465, Accuracy: 0.777


                                                             

Epoch [98/100], Loss: 0.485, Accuracy: 0.773


                                                             

Epoch [99/100], Loss: 0.479, Accuracy: 0.774


                                                              

Epoch [100/100], Loss: 0.467, Accuracy: 0.777




## 4. Checking on the validation dataset

In [29]:
loss_val = 0.0
acc_val = 0.0

for sample in tqdm(dataloader_val):
    matrix, label = sample[0].float().to(device), sample[1].to(device)  # Convert the data to float32
    pred = model_resnet(matrix)
    loss = criterion(pred, label)

    loss_item = loss.item()
    loss_val += loss_item

    acc_current = accuracy(pred.cpu().float(), label.cpu().float())
    acc_val += acc_current

print(f'Loss: {loss_val/len(dataloader_val):.5f}, Accuracy: {acc_val/len(dataloader_val):.3f}')

100%|██████████| 3/3 [00:00<00:00,  8.09it/s]

Loss: 0.49940, Accuracy: 0.760



