# Wine aroma prediction using Vision Transformer
Here will be used Vision Transformer to predict wine aroma from the previously obtained wine composition matrices

In [None]:
#Import the necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split

import torchvision as tv

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

from torch.cuda.amp import autocast, GradScaler

import plotly.graph_objects as go
import seaborn as sns
import plotly.express as px
import pandas as pd

from transformers import ViTForImageClassification, ViTFeatureExtractor
import torch
import torch.nn.functional as F
from torchvision.transforms import functional as TF

## 1. Working with data
At this stage, Load the previously prepared data, and process the data for further neural networking 

In [2]:
# Loading data
X_array =  np.load('X_array.npy')
Y_array =  np.load('Y_array.npy')

In [3]:
# Separation of data into training, validation and test data in the ratio of 70:20:10
X_train, X_test, y_train, y_test = train_test_split(X_array, Y_array, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)

In [4]:
# Conversion to a tensor
X_train = torch.tensor(X_train)
y_train = torch.tensor(y_train)
X_val = torch.tensor(X_val)
y_val = torch.tensor(y_val)
X_test = torch.tensor(X_test)
y_test = torch.tensor(y_test)

In [5]:
print(len(X_train))
print(len(y_train))
print(len(X_val))
print(len(y_val))
print(len(X_test))
print(len(y_test))

323
323
81
81
45
45


Since Vision Transformer has been trained on 224x224 images, need to bring images to the same size by padding (adding white pixels in equal proportions to the bottom, top and sides of the image), a method chosen to preserve chemical information without distorting it through interpolation

In [None]:
# Create a custom dataset
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x, y):
        # Add channels if there are none (i.e. if the images have dimension [N, H, W]).
        if x.dim() == 3:  
            x = x.unsqueeze(1)  # Add channel dimensionality [N, 1, H, W]

        self.x = self.pad_to_224(x)  # Adding padding
        self.y = y

    def pad_to_224(self, x):
        padded_images = []
        for img in x:
            # Check that the input images are of size [1, 44, 100]
            _, h, w = img.shape
            assert h == 44 and w == 100, "Incorrect size of the input image"

            # Applying padding
            img_padded = TF.pad(img, (62, 90), fill=255)  # white background

            padded_images.append(img_padded)

        return torch.stack(padded_images)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

# Create Dataset and DataLoader
dataset_train = CustomDataset(X_train, y_train)
dataset_val = CustomDataset(X_val, y_val)
dataset_test = CustomDataset(X_test, y_test)

dataloader_train = torch.utils.data.DataLoader(dataset_train, batch_size=32, shuffle=True)
dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=32, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=32, shuffle=False)


In [7]:
for x, y in dataloader_train:
    print(x.size())

torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([32, 1, 224, 224])
torch.Size([3, 1, 224, 224])


## 3. Loading of the pre-trained model and its additional training
At this stage will load the pre-trained model, change the necessary parameters, and retrain it on data

In [None]:
from transformers import ViTConfig, ViTForImageClassification

# Downloading the configuration and updating it
model_name = "google/vit-base-patch16-224"
config = ViTConfig.from_pretrained(model_name)

# Configuration setting for single input channel operation
config.num_channels = 1
config.num_labels = 10 


# Create a model with a new configuration
model = ViTForImageClassification(config)

# Replace the patch projection layer with a new one with one input channel
new_projection_layer = torch.nn.Conv2d(
    in_channels=config.num_channels,  # One input channel
    out_channels=config.hidden_size,
    kernel_size=(16, 16),
    stride=(16, 16)
)
model.vit.embeddings.patch_embeddings.projection = new_projection_layer

# Transfer the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(1, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [10]:
# Definition of loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))


In [11]:
# Accuracy function
def accuracy(pred, label, threshold=0.5):
    pred = torch.sigmoid(pred)
    pred_labels = (pred > threshold).float()  # Threshold the predictions
    correct = (pred_labels == label).sum().item()  # Compare predictions with labels
    total = label.size(0) * label.size(1)  # Total number of labels
    return correct / total

In [None]:
# Model retraining

num_epochs = 100 

# Transfer the model to the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    running_loss = 0.0
    correct_predictions = 0
    total_samples = 0
    acc_val = 0.0
    
    for matrix, labels in tqdm(dataloader_train, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False):
        matrix, labels = matrix.to(device), labels.to(device)
        optimizer.zero_grad()
        
        outputs = model(matrix)

        loss = criterion(outputs.logits, labels)
        loss.backward()
        optimizer.step()
        
        # Clearing GPU memory
        torch.cuda.empty_cache()

        loss_item = loss.item()
        running_loss += loss_item
                
        acc_current = accuracy(outputs.logits.cpu().float(), labels.cpu().float())
        acc_val += acc_current

      
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader_train):.3f}, Accuracy: {acc_val/len(dataloader_train):.3f}") 

  context_layer = torch.nn.functional.scaled_dot_product_attention(
                                                            

Epoch [1/100], Loss: 0.535, Accuracy: 0.730


                                                            

Epoch [2/100], Loss: 0.477, Accuracy: 0.779


                                                            

Epoch [3/100], Loss: 0.484, Accuracy: 0.777


                                                            

Epoch [4/100], Loss: 0.476, Accuracy: 0.773


                                                            

Epoch [5/100], Loss: 0.472, Accuracy: 0.785


                                                            

Epoch [6/100], Loss: 0.467, Accuracy: 0.782


                                                            

Epoch [7/100], Loss: 0.488, Accuracy: 0.762


                                                            

Epoch [8/100], Loss: 0.496, Accuracy: 0.751


                                                            

Epoch [9/100], Loss: 0.484, Accuracy: 0.765


                                                             

Epoch [10/100], Loss: 0.478, Accuracy: 0.771


                                                             

Epoch [11/100], Loss: 0.477, Accuracy: 0.776


                                                             

Epoch [12/100], Loss: 0.485, Accuracy: 0.766


                                                             

Epoch [13/100], Loss: 0.468, Accuracy: 0.776


                                                             

Epoch [14/100], Loss: 0.486, Accuracy: 0.760


                                                             

Epoch [15/100], Loss: 0.473, Accuracy: 0.779


                                                             

Epoch [16/100], Loss: 0.478, Accuracy: 0.774


                                                             

Epoch [17/100], Loss: 0.476, Accuracy: 0.770


                                                             

Epoch [18/100], Loss: 0.473, Accuracy: 0.773


                                                             

Epoch [19/100], Loss: 0.466, Accuracy: 0.783


                                                             

Epoch [20/100], Loss: 0.463, Accuracy: 0.784


                                                             

Epoch [21/100], Loss: 0.481, Accuracy: 0.770


                                                             

Epoch [22/100], Loss: 0.465, Accuracy: 0.783


                                                             

Epoch [23/100], Loss: 0.488, Accuracy: 0.777


                                                             

Epoch [24/100], Loss: 0.480, Accuracy: 0.770


                                                             

Epoch [25/100], Loss: 0.470, Accuracy: 0.774


                                                             

Epoch [26/100], Loss: 0.472, Accuracy: 0.770


                                                             

Epoch [27/100], Loss: 0.483, Accuracy: 0.771


                                                             

Epoch [28/100], Loss: 0.487, Accuracy: 0.759


                                                             

Epoch [29/100], Loss: 0.483, Accuracy: 0.772


                                                             

Epoch [30/100], Loss: 0.484, Accuracy: 0.767


                                                             

Epoch [31/100], Loss: 0.478, Accuracy: 0.774


                                                             

Epoch [32/100], Loss: 0.491, Accuracy: 0.767


                                                             

Epoch [33/100], Loss: 0.462, Accuracy: 0.789


                                                             

Epoch [34/100], Loss: 0.476, Accuracy: 0.769


                                                             

Epoch [35/100], Loss: 0.471, Accuracy: 0.776


                                                             

Epoch [36/100], Loss: 0.477, Accuracy: 0.771


                                                             

Epoch [37/100], Loss: 0.474, Accuracy: 0.780


                                                             

Epoch [38/100], Loss: 0.494, Accuracy: 0.764


                                                             

Epoch [39/100], Loss: 0.475, Accuracy: 0.777


                                                             

Epoch [40/100], Loss: 0.479, Accuracy: 0.770


                                                             

Epoch [41/100], Loss: 0.480, Accuracy: 0.774


                                                             

Epoch [42/100], Loss: 0.480, Accuracy: 0.769


                                                             

Epoch [43/100], Loss: 0.473, Accuracy: 0.780


                                                             

Epoch [44/100], Loss: 0.465, Accuracy: 0.788


                                                             

Epoch [45/100], Loss: 0.469, Accuracy: 0.783


                                                             

Epoch [46/100], Loss: 0.486, Accuracy: 0.777


                                                             

Epoch [47/100], Loss: 0.479, Accuracy: 0.779


                                                             

Epoch [48/100], Loss: 0.475, Accuracy: 0.780


                                                             

Epoch [49/100], Loss: 0.469, Accuracy: 0.780


                                                             

Epoch [50/100], Loss: 0.471, Accuracy: 0.769


                                                             

Epoch [51/100], Loss: 0.493, Accuracy: 0.762


                                                             

Epoch [52/100], Loss: 0.484, Accuracy: 0.774


                                                             

Epoch [53/100], Loss: 0.468, Accuracy: 0.780


                                                             

Epoch [54/100], Loss: 0.482, Accuracy: 0.772


                                                             

Epoch [55/100], Loss: 0.486, Accuracy: 0.783


                                                             

Epoch [56/100], Loss: 0.477, Accuracy: 0.777


                                                             

Epoch [57/100], Loss: 0.470, Accuracy: 0.776


                                                             

Epoch [58/100], Loss: 0.479, Accuracy: 0.770


                                                             

Epoch [59/100], Loss: 0.475, Accuracy: 0.772


                                                             

Epoch [60/100], Loss: 0.475, Accuracy: 0.765


                                                             

Epoch [61/100], Loss: 0.476, Accuracy: 0.758


                                                             

Epoch [62/100], Loss: 0.471, Accuracy: 0.785


                                                             

Epoch [63/100], Loss: 0.470, Accuracy: 0.780


                                                             

Epoch [64/100], Loss: 0.470, Accuracy: 0.770


                                                             

Epoch [65/100], Loss: 0.481, Accuracy: 0.774


                                                             

Epoch [66/100], Loss: 0.479, Accuracy: 0.773


                                                             

Epoch [67/100], Loss: 0.474, Accuracy: 0.777


                                                             

Epoch [68/100], Loss: 0.468, Accuracy: 0.783


                                                             

Epoch [69/100], Loss: 0.464, Accuracy: 0.785


                                                             

Epoch [70/100], Loss: 0.473, Accuracy: 0.772


                                                             

Epoch [71/100], Loss: 0.489, Accuracy: 0.761


                                                             

Epoch [72/100], Loss: 0.467, Accuracy: 0.777


                                                             

Epoch [73/100], Loss: 0.458, Accuracy: 0.785


                                                             

Epoch [74/100], Loss: 0.470, Accuracy: 0.785


                                                             

Epoch [75/100], Loss: 0.469, Accuracy: 0.780


                                                             

Epoch [76/100], Loss: 0.477, Accuracy: 0.780


                                                             

Epoch [77/100], Loss: 0.467, Accuracy: 0.783


                                                             

Epoch [78/100], Loss: 0.460, Accuracy: 0.779


                                                             

Epoch [79/100], Loss: 0.481, Accuracy: 0.763


                                                             

Epoch [80/100], Loss: 0.474, Accuracy: 0.778


                                                             

Epoch [81/100], Loss: 0.466, Accuracy: 0.780


                                                             

Epoch [82/100], Loss: 0.466, Accuracy: 0.780


                                                             

Epoch [83/100], Loss: 0.464, Accuracy: 0.781


                                                             

Epoch [84/100], Loss: 0.484, Accuracy: 0.767


                                                             

Epoch [85/100], Loss: 0.475, Accuracy: 0.777


                                                             

Epoch [86/100], Loss: 0.488, Accuracy: 0.760


                                                             

Epoch [87/100], Loss: 0.463, Accuracy: 0.787


                                                             

Epoch [88/100], Loss: 0.477, Accuracy: 0.776


                                                             

Epoch [89/100], Loss: 0.485, Accuracy: 0.766


                                                             

Epoch [90/100], Loss: 0.472, Accuracy: 0.767


                                                             

Epoch [91/100], Loss: 0.482, Accuracy: 0.776


                                                             

Epoch [92/100], Loss: 0.482, Accuracy: 0.770


                                                             

Epoch [93/100], Loss: 0.481, Accuracy: 0.774


                                                             

Epoch [94/100], Loss: 0.464, Accuracy: 0.785


                                                             

Epoch [95/100], Loss: 0.465, Accuracy: 0.780


                                                             

Epoch [96/100], Loss: 0.475, Accuracy: 0.772


                                                             

Epoch [97/100], Loss: 0.486, Accuracy: 0.775


                                                             

Epoch [98/100], Loss: 0.471, Accuracy: 0.777


                                                             

Epoch [99/100], Loss: 0.472, Accuracy: 0.778


                                                              

Epoch [100/100], Loss: 0.468, Accuracy: 0.777




In [13]:
torch.cuda.empty_cache()

## 4. Checking on the validation dataset

In [None]:
loss_val = 0.0
acc_val = 0.0
for sample in tqdm(dataloader_val):
    matrix, label = sample[0].to(device), sample[1].to(device)
  
    pred = model(matrix)
    loss = criterion(pred.logits, label)

    torch.cuda.empty_cache()

    loss_item = loss.item()
    loss_val += loss_item


    acc_current = accuracy(pred.logits.cpu().float(), label.cpu().float())
    acc_val += acc_current


print(f'Loss: {loss_val/len(dataloader_val):.3f}, Accuracy: {acc_val/len(dataloader_val):.3f}')

100%|██████████| 3/3 [00:04<00:00,  1.61s/it]

Loss: 0.501, Accuracy: 0.758



