# Wine aroma prediction using CNNs
Here we use CNN to predict wine aroma from the previously obtained wine composition matrices

In [None]:
#Import the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
import torchvision as tv
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm
from torch.cuda.amp import autocast, GradScaler
import plotly.graph_objects as go
import seaborn as sns
import plotly.express as px
import pandas as pd 

## 1. Working with data
At this stage, we load the previously prepared data, and process the data for further neural networking 

In [537]:
# Loading data
X_array =  np.load('X_array.npy')
Y_array =  np.load('Y_array.npy')

In [538]:
# Separation of data into training, validation and test data in the ratio of 70:20:10
X_train, X_test, y_train, y_test = train_test_split(X_array, Y_array, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

In [539]:
# Conversion to a tensor
X_train = torch.tensor(X_train)
y_train = torch.tensor(y_train)
X_val = torch.tensor(X_val)
y_val = torch.tensor(y_val)
X_test = torch.tensor(X_test)
y_test = torch.tensor(y_test)

In [540]:
print(len(X_train))
print(len(y_train))
print(len(X_val))
print(len(y_val))
print(len(X_test))
print(len(y_test))

323
323
81
81
45
45


In [541]:
# Create a custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [542]:
# Create Dataset and DataLoader
dataset_train = CustomDataset(X_train, y_train)
dataset_val = CustomDataset(X_val, y_val)
dataset_test = CustomDataset(X_test, y_test)
dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=32, shuffle=False)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False)

## 2. CNN architecture
At this step we design the architecture of the neural network, in this file the optimal architecture is specified

In [543]:
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # Define convolution layers with batch normalization
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        
        # Pulling
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        
        # Fully connected layers
        self.fc1 = nn.Linear(128 * 11 * 25, 512)  
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 10)
        
        # Dropout to prevent overtraining
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        x = self.pool(F.leaky_relu(self.bn1(self.conv1(x))))
        x = self.pool(F.leaky_relu(self.bn2(self.conv2(x))))
 
        # Straightening before feeding to the full-link layer
        x = x.view(-1, 128 * 11 * 25)  # Flatten
        
        # Fully connected layers with activation and dropout
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout(x)  # Dropout after the first full-link layer
        
        x = F.leaky_relu(self.fc2(x))
        x = self.fc3(x)
        
        return x
    

# Creating an instance of a neural network
net = ConvNet()
# Print the architecture of the neural network
print(net)

ConvNet(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (fc1): Linear(in_features=35200, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)


## 3. Neural network training
Set the training parameters and train the neural network

In [None]:
loss_fn = nn.BCEWithLogitsLoss() 
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, betas=(0.9, 0.999), weight_decay=1e-5) 
scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer,
    gamma = 0.6
)

In [545]:
# Accuracy function
def accuracy(pred, label, threshold=0.5):
    pred = torch.sigmoid(pred)
    pred_labels = (pred > threshold).float()  # Threshold the predictions
    correct = (pred_labels == label).sum().item()  # Compare predictions with labels
    total = label.size(0) * label.size(1)  # Total number of labels
    return correct / total

In [546]:
# Transferring the neural network to cuda
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = net.to(device)
loss_fn = loss_fn.to(device)

In [547]:
use_amp = True
scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False


`torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.



In [None]:
# Neural network training
epochs = 100 

for epoch in range(epochs):
    loss_val = 0.0
    acc_val = 0.0
    for sample in tqdm(dataloader_train):
        matrix, label = sample[0].to(device), sample[1].to(device) #Transferring to cuda
        matrix = matrix.float()
        optimizer.zero_grad()

        # Let's add channel measurement to the data
        matrix = matrix.unsqueeze(1)
  
        with autocast(use_amp):
            pred = model(matrix)
            loss = loss_fn(pred, label.float()) 

        scaler.scale(loss).backward()
        loss_item = loss.item()
        loss_val += loss_item

        scaler.step(optimizer)
        scaler.update()

        # Accuracy calculation
        acc_current = accuracy(pred.cpu().float(), label.cpu().float())
        acc_val += acc_current

    print(f'Epoch: [{epoch+1}/{epochs}], Loss: {loss_val/len(dataloader_train):.5f}, Accuracy: {acc_val/len(dataloader_train):.3f}')


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

100%|██████████| 11/11 [00:00<00:00, 19.00it/s]


Epoch: [1/100], Loss: 0.56558, Accuracy: 0.739


100%|██████████| 11/11 [00:00<00:00, 41.37it/s]


Epoch: [2/100], Loss: 0.46993, Accuracy: 0.791


100%|██████████| 11/11 [00:00<00:00, 42.30it/s]


Epoch: [3/100], Loss: 0.46761, Accuracy: 0.789


100%|██████████| 11/11 [00:00<00:00, 43.32it/s]


Epoch: [4/100], Loss: 0.47600, Accuracy: 0.787


100%|██████████| 11/11 [00:00<00:00, 43.73it/s]


Epoch: [5/100], Loss: 0.45225, Accuracy: 0.808


100%|██████████| 11/11 [00:00<00:00, 42.50it/s]


Epoch: [6/100], Loss: 0.45559, Accuracy: 0.799


100%|██████████| 11/11 [00:00<00:00, 42.12it/s]


Epoch: [7/100], Loss: 0.44377, Accuracy: 0.797


100%|██████████| 11/11 [00:00<00:00, 38.05it/s]


Epoch: [8/100], Loss: 0.44122, Accuracy: 0.805


100%|██████████| 11/11 [00:00<00:00, 42.87it/s]


Epoch: [9/100], Loss: 0.44220, Accuracy: 0.803


100%|██████████| 11/11 [00:00<00:00, 43.51it/s]


Epoch: [10/100], Loss: 0.47196, Accuracy: 0.795


100%|██████████| 11/11 [00:00<00:00, 43.61it/s]


Epoch: [11/100], Loss: 0.44232, Accuracy: 0.801


100%|██████████| 11/11 [00:00<00:00, 44.97it/s]


Epoch: [12/100], Loss: 0.41526, Accuracy: 0.818


100%|██████████| 11/11 [00:00<00:00, 43.96it/s]


Epoch: [13/100], Loss: 0.41769, Accuracy: 0.811


100%|██████████| 11/11 [00:00<00:00, 44.34it/s]


Epoch: [14/100], Loss: 0.41271, Accuracy: 0.815


100%|██████████| 11/11 [00:00<00:00, 45.01it/s]


Epoch: [15/100], Loss: 0.42459, Accuracy: 0.810


100%|██████████| 11/11 [00:00<00:00, 44.91it/s]


Epoch: [16/100], Loss: 0.42668, Accuracy: 0.811


100%|██████████| 11/11 [00:00<00:00, 43.31it/s]


Epoch: [17/100], Loss: 0.42047, Accuracy: 0.805


100%|██████████| 11/11 [00:00<00:00, 45.06it/s]


Epoch: [18/100], Loss: 0.42030, Accuracy: 0.812


100%|██████████| 11/11 [00:00<00:00, 44.90it/s]


Epoch: [19/100], Loss: 0.40184, Accuracy: 0.822


100%|██████████| 11/11 [00:00<00:00, 44.86it/s]


Epoch: [20/100], Loss: 0.41439, Accuracy: 0.817


100%|██████████| 11/11 [00:00<00:00, 43.81it/s]


Epoch: [21/100], Loss: 0.41454, Accuracy: 0.819


100%|██████████| 11/11 [00:00<00:00, 45.53it/s]


Epoch: [22/100], Loss: 0.40979, Accuracy: 0.816


100%|██████████| 11/11 [00:00<00:00, 44.49it/s]


Epoch: [23/100], Loss: 0.38502, Accuracy: 0.827


100%|██████████| 11/11 [00:00<00:00, 44.61it/s]


Epoch: [24/100], Loss: 0.38992, Accuracy: 0.826


100%|██████████| 11/11 [00:00<00:00, 43.44it/s]


Epoch: [25/100], Loss: 0.41071, Accuracy: 0.827


100%|██████████| 11/11 [00:00<00:00, 45.44it/s]


Epoch: [26/100], Loss: 0.41018, Accuracy: 0.803


100%|██████████| 11/11 [00:00<00:00, 45.20it/s]


Epoch: [27/100], Loss: 0.38479, Accuracy: 0.832


100%|██████████| 11/11 [00:00<00:00, 44.54it/s]


Epoch: [28/100], Loss: 0.38395, Accuracy: 0.834


100%|██████████| 11/11 [00:00<00:00, 44.97it/s]


Epoch: [29/100], Loss: 0.37880, Accuracy: 0.835


100%|██████████| 11/11 [00:00<00:00, 44.29it/s]


Epoch: [30/100], Loss: 0.37850, Accuracy: 0.838


100%|██████████| 11/11 [00:00<00:00, 45.17it/s]


Epoch: [31/100], Loss: 0.37846, Accuracy: 0.838


100%|██████████| 11/11 [00:00<00:00, 44.75it/s]


Epoch: [32/100], Loss: 0.37533, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 44.25it/s]


Epoch: [33/100], Loss: 0.38681, Accuracy: 0.829


100%|██████████| 11/11 [00:00<00:00, 43.44it/s]


Epoch: [34/100], Loss: 0.38687, Accuracy: 0.835


100%|██████████| 11/11 [00:00<00:00, 42.55it/s]


Epoch: [35/100], Loss: 0.38706, Accuracy: 0.830


100%|██████████| 11/11 [00:00<00:00, 44.33it/s]


Epoch: [36/100], Loss: 0.38769, Accuracy: 0.827


100%|██████████| 11/11 [00:00<00:00, 44.18it/s]


Epoch: [37/100], Loss: 0.39817, Accuracy: 0.818


100%|██████████| 11/11 [00:00<00:00, 45.45it/s]


Epoch: [38/100], Loss: 0.39046, Accuracy: 0.826


100%|██████████| 11/11 [00:00<00:00, 43.83it/s]


Epoch: [39/100], Loss: 0.37933, Accuracy: 0.836


100%|██████████| 11/11 [00:00<00:00, 45.45it/s]


Epoch: [40/100], Loss: 0.37829, Accuracy: 0.834


100%|██████████| 11/11 [00:00<00:00, 43.65it/s]


Epoch: [41/100], Loss: 0.37847, Accuracy: 0.828


100%|██████████| 11/11 [00:00<00:00, 42.80it/s]


Epoch: [42/100], Loss: 0.38946, Accuracy: 0.832


100%|██████████| 11/11 [00:00<00:00, 44.00it/s]


Epoch: [43/100], Loss: 0.37127, Accuracy: 0.843


100%|██████████| 11/11 [00:00<00:00, 43.62it/s]


Epoch: [44/100], Loss: 0.38127, Accuracy: 0.834


100%|██████████| 11/11 [00:00<00:00, 45.35it/s]


Epoch: [45/100], Loss: 0.36797, Accuracy: 0.835


100%|██████████| 11/11 [00:00<00:00, 43.56it/s]


Epoch: [46/100], Loss: 0.36663, Accuracy: 0.840


100%|██████████| 11/11 [00:00<00:00, 45.65it/s]


Epoch: [47/100], Loss: 0.37591, Accuracy: 0.836


100%|██████████| 11/11 [00:00<00:00, 44.71it/s]


Epoch: [48/100], Loss: 0.39582, Accuracy: 0.830


100%|██████████| 11/11 [00:00<00:00, 43.47it/s]


Epoch: [49/100], Loss: 0.38312, Accuracy: 0.828


100%|██████████| 11/11 [00:00<00:00, 45.48it/s]


Epoch: [50/100], Loss: 0.37290, Accuracy: 0.835


100%|██████████| 11/11 [00:00<00:00, 45.98it/s]


Epoch: [51/100], Loss: 0.36802, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 45.33it/s]


Epoch: [52/100], Loss: 0.37523, Accuracy: 0.833


100%|██████████| 11/11 [00:00<00:00, 45.32it/s]


Epoch: [53/100], Loss: 0.37143, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 43.47it/s]


Epoch: [54/100], Loss: 0.36199, Accuracy: 0.839


100%|██████████| 11/11 [00:00<00:00, 43.60it/s]


Epoch: [55/100], Loss: 0.35812, Accuracy: 0.839


100%|██████████| 11/11 [00:00<00:00, 44.51it/s]


Epoch: [56/100], Loss: 0.34820, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 44.84it/s]


Epoch: [57/100], Loss: 0.37283, Accuracy: 0.836


100%|██████████| 11/11 [00:00<00:00, 45.25it/s]


Epoch: [58/100], Loss: 0.36171, Accuracy: 0.834


100%|██████████| 11/11 [00:00<00:00, 45.87it/s]


Epoch: [59/100], Loss: 0.35620, Accuracy: 0.838


100%|██████████| 11/11 [00:00<00:00, 45.53it/s]


Epoch: [60/100], Loss: 0.34988, Accuracy: 0.850


100%|██████████| 11/11 [00:00<00:00, 44.97it/s]


Epoch: [61/100], Loss: 0.34272, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 44.14it/s]


Epoch: [62/100], Loss: 0.34978, Accuracy: 0.853


100%|██████████| 11/11 [00:00<00:00, 44.74it/s]


Epoch: [63/100], Loss: 0.37276, Accuracy: 0.839


100%|██████████| 11/11 [00:00<00:00, 45.64it/s]


Epoch: [64/100], Loss: 0.36079, Accuracy: 0.843


100%|██████████| 11/11 [00:00<00:00, 45.70it/s]


Epoch: [65/100], Loss: 0.35002, Accuracy: 0.845


100%|██████████| 11/11 [00:00<00:00, 45.52it/s]


Epoch: [66/100], Loss: 0.34930, Accuracy: 0.850


100%|██████████| 11/11 [00:00<00:00, 47.36it/s]


Epoch: [67/100], Loss: 0.36479, Accuracy: 0.837


100%|██████████| 11/11 [00:00<00:00, 45.39it/s]


Epoch: [68/100], Loss: 0.35039, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 44.32it/s]


Epoch: [69/100], Loss: 0.34176, Accuracy: 0.839


100%|██████████| 11/11 [00:00<00:00, 45.53it/s]


Epoch: [70/100], Loss: 0.35750, Accuracy: 0.843


100%|██████████| 11/11 [00:00<00:00, 45.75it/s]


Epoch: [71/100], Loss: 0.35361, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 45.54it/s]


Epoch: [72/100], Loss: 0.34967, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 46.03it/s]


Epoch: [73/100], Loss: 0.35126, Accuracy: 0.846


100%|██████████| 11/11 [00:00<00:00, 45.08it/s]


Epoch: [74/100], Loss: 0.33421, Accuracy: 0.850


100%|██████████| 11/11 [00:00<00:00, 44.87it/s]


Epoch: [75/100], Loss: 0.34229, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 44.01it/s]


Epoch: [76/100], Loss: 0.34595, Accuracy: 0.844


100%|██████████| 11/11 [00:00<00:00, 45.56it/s]


Epoch: [77/100], Loss: 0.34760, Accuracy: 0.842


100%|██████████| 11/11 [00:00<00:00, 44.76it/s]


Epoch: [78/100], Loss: 0.34914, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 45.66it/s]


Epoch: [79/100], Loss: 0.35620, Accuracy: 0.836


100%|██████████| 11/11 [00:00<00:00, 45.76it/s]


Epoch: [80/100], Loss: 0.32543, Accuracy: 0.860


100%|██████████| 11/11 [00:00<00:00, 45.09it/s]


Epoch: [81/100], Loss: 0.32206, Accuracy: 0.857


100%|██████████| 11/11 [00:00<00:00, 45.53it/s]


Epoch: [82/100], Loss: 0.32974, Accuracy: 0.852


100%|██████████| 11/11 [00:00<00:00, 45.29it/s]


Epoch: [83/100], Loss: 0.31864, Accuracy: 0.856


100%|██████████| 11/11 [00:00<00:00, 44.58it/s]


Epoch: [84/100], Loss: 0.33867, Accuracy: 0.849


100%|██████████| 11/11 [00:00<00:00, 45.92it/s]


Epoch: [85/100], Loss: 0.33346, Accuracy: 0.853


100%|██████████| 11/11 [00:00<00:00, 44.74it/s]


Epoch: [86/100], Loss: 0.34672, Accuracy: 0.843


100%|██████████| 11/11 [00:00<00:00, 45.21it/s]


Epoch: [87/100], Loss: 0.33003, Accuracy: 0.855


100%|██████████| 11/11 [00:00<00:00, 44.08it/s]


Epoch: [88/100], Loss: 0.34615, Accuracy: 0.839


100%|██████████| 11/11 [00:00<00:00, 43.13it/s]


Epoch: [89/100], Loss: 0.32171, Accuracy: 0.855


100%|██████████| 11/11 [00:00<00:00, 45.03it/s]


Epoch: [90/100], Loss: 0.32570, Accuracy: 0.855


100%|██████████| 11/11 [00:00<00:00, 44.86it/s]


Epoch: [91/100], Loss: 0.32983, Accuracy: 0.847


100%|██████████| 11/11 [00:00<00:00, 45.73it/s]


Epoch: [92/100], Loss: 0.31644, Accuracy: 0.861


100%|██████████| 11/11 [00:00<00:00, 44.79it/s]


Epoch: [93/100], Loss: 0.31785, Accuracy: 0.859


100%|██████████| 11/11 [00:00<00:00, 45.93it/s]


Epoch: [94/100], Loss: 0.32987, Accuracy: 0.855


100%|██████████| 11/11 [00:00<00:00, 45.44it/s]


Epoch: [95/100], Loss: 0.31491, Accuracy: 0.859


100%|██████████| 11/11 [00:00<00:00, 43.66it/s]


Epoch: [96/100], Loss: 0.31248, Accuracy: 0.857


100%|██████████| 11/11 [00:00<00:00, 45.63it/s]


Epoch: [97/100], Loss: 0.31816, Accuracy: 0.861


100%|██████████| 11/11 [00:00<00:00, 45.61it/s]


Epoch: [98/100], Loss: 0.34349, Accuracy: 0.847


100%|██████████| 11/11 [00:00<00:00, 45.54it/s]


Epoch: [99/100], Loss: 0.32506, Accuracy: 0.856


100%|██████████| 11/11 [00:00<00:00, 45.62it/s]

Epoch: [100/100], Loss: 0.32920, Accuracy: 0.853





## 4. Checking on the validation dataset

In [549]:
loss_val = 0.0
acc_val = 0.0
for sample in tqdm(dataloader_val):
    matrix, label = sample[0].to(device), sample[1].to(device)
    matrix = matrix.float()
    # Add channel measurement to the data
    matrix = matrix.unsqueeze(1)
  
    with autocast(use_amp):
        pred = model(matrix)
        loss = loss_fn(pred, label)

    loss_item = loss.item()
    loss_val += loss_item


    acc_current = accuracy(pred.cpu().float(), label.cpu().float())
    acc_val += acc_current


print(f'Loss: {loss_val/len(dataloader_val):.5f}, Accuracy: {acc_val/len(dataloader_val):.3f}')


`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

100%|██████████| 3/3 [00:00<00:00, 120.04it/s]

Loss: 0.58933, Accuracy: 0.779



