# Chapter 4 - Train on Multimodal Data

Here we will show how to train the model on a folder of images along with associated metadata.

## Data Loader
Again, the first thing we need to do is define our dataloader (what kind of data we expect).  Unlike in Chapter 2, we also expect a CSV file that holds, for each image, its metadata (otolith length, otolith weight, and month) together with the age label.

Typically this data should be split into three different sets: a training set, a validation set, and a testing set.  The training set (\~60%-70%), as the name suggests, is used to actually train the model.  The validation set (\~10%-20%) is used during training to choose the best performing model.  This is necessary because the model changes at each step of the training phase, and the model at the very end of training may not be the best one due to overfitting.  Finally, the testing set (\~10%-20%) is used to evaluate the actual model performance (such as accuracy) on unseen data.  For this example, we will use the same set of data for all three purposes.

![image.png](attachment:image.png)

In [53]:
import os
from os import listdir
from os.path import isfile, join
from typing import Any, Callable, List, Optional, Type, Union

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset  # For custom datasets
from torchvision import datasets
from torchvision import transforms
from torchvision.io import read_image
from torchvision.transforms import ToTensor

class FishTestDataset(Dataset):
    """Dataset pairing each image with its tabular metadata and age label.

    The CSV is read positionally: column 0 is the image file name, column 1
    the otolith length, column 2 the otolith weight, column 3 the month and
    column 4 the fish age.

    Each item is ``((image, metadata), image_name, label)`` where

    - ``image`` is the RGB image after the optional ``transform``,
    - ``metadata`` is a float32 tensor ``[weight, length, month]`` with each
      value shifted and scaled by the constants below,
    - ``label`` is the age as a Python ``int``; every age of 5 or more is
      mapped to class 4.
    """

    # (offset, scale) applied to each metadata column as (value - offset) / scale.
    # NOTE(review): presumably the training-set mean and standard deviation
    # of each column -- confirm against the data.
    WT_NORM = (163, 82)
    LENGTH_NORM = (211, 35.5)
    MONTH_NORM = (7.4, 1.9)

    # Highest class index; all ages above it share this class.
    OLDEST_AGE_CLASS = 4

    def __init__(self, image_dir, csv_path, transform=None):
        """Read the CSV and cache its columns.

        Args:
            image_dir: Directory that contains the image files named in the CSV.
            csv_path: Path of the CSV file (first row is the header).
            transform: Optional callable applied to each PIL image.
        """
        # Read the csv file
        self.data_info = pd.read_csv(csv_path, header=0)

        # Directory holding the dataset images
        self.image_dir = image_dir

        # Transform applied to every image (may be None)
        self.transforms = transform

        columns = self.data_info
        self.image_name = np.asarray(columns.iloc[:, 0])  # Image name
        self.length = np.asarray(columns.iloc[:, 1])      # Otolith length
        self.wt = np.asarray(columns.iloc[:, 2])          # Otolith weight
        self.month = np.asarray(columns.iloc[:, 3])       # Month
        self.age = np.asarray(columns.iloc[:, 4])         # Fish age

    def __len__(self):
        """Return the number of rows (samples) in the CSV."""
        return len(self.image_name)

    def _metadata(self, index):
        """Return the scaled ``[weight, length, month]`` tensor for one row."""
        wt_offset, wt_scale = self.WT_NORM
        len_offset, len_scale = self.LENGTH_NORM
        month_offset, month_scale = self.MONTH_NORM
        # float32 is requested explicitly: numpy scalars would otherwise give a
        # float64 tensor, which does not match the float32 weights of nn.Linear.
        return torch.tensor(
            [
                (self.wt[index] - wt_offset) / wt_scale,
                (self.length[index] - len_offset) / len_scale,
                (self.month[index] - month_offset) / month_scale,
            ],
            dtype=torch.float32,
        )

    def _label(self, index):
        """Return the class index for one row as a plain ``int``."""
        age = self.age[index]
        # A Python int collates to an int64 tensor whatever dtype pandas
        # inferred for the age column.
        if age < self.OLDEST_AGE_CLASS + 1:
            return int(age)
        return self.OLDEST_AGE_CLASS

    def __getitem__(self, index):
        """Load sample ``index`` as ``((image, metadata), image_name, label)``."""
        img_path = os.path.join(self.image_dir, str(self.image_name[index]))
        # Close the file promptly and force three channels so grayscale or
        # RGBA files still work with 3-channel transforms.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transforms:
            image = self.transforms(image)

        return (image, self._metadata(index)), self.image_name[index], self._label(index)
        
# Folder with the images and CSV file with one metadata row per image.
data_dir = 'cropped'
csv_path = "train.csv"

# Resize/crop every image to 224x224, convert it to a tensor and normalise
# each channel with the given per-channel mean and standard deviation.
data_transforms = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

# In this example the same folder/CSV is used for all three splits.
train_dataset = FishTestDataset(data_dir, csv_path, data_transforms)
val_dataset = FishTestDataset(data_dir, csv_path, data_transforms)
test_dataset = FishTestDataset(data_dir, csv_path, data_transforms)

# Each loader wraps the dataset of the same name. Only the training data is
# shuffled; evaluation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=24, shuffle=True, drop_last=False)
val_loader = DataLoader(val_dataset, batch_size=24, shuffle=False, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=24, shuffle=False, drop_last=False)

## Adjust Model Architecture for Multi-Modal Data

We copy the basic resnet model building block code.

In [None]:
# function creating a 3x3 convolutional layer
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """
    Build a bias-free 3x3 convolution whose padding equals its dilation.

    Args:
    - in_planes (int): Number of input channels.
    - out_planes (int): Number of output channels.
    - stride (int): Stride of the convolution (default: 1).
    - groups (int): Number of groups for grouped convolution (default: 1).
    - dilation (int): Dilation rate, also used as the padding (default: 1).

    Returns:
    - nn.Conv2d: The configured 3x3 convolutional layer.
    """
    # Padding tracks dilation so that a stride-1 layer keeps the spatial size.
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )

# function creating a 1x1 convolutional layer
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """
    Build a bias-free 1x1 (pointwise) convolution.

    Args:
    - in_planes (int): Number of input channels.
    - out_planes (int): Number of output channels.
    - stride (int): Stride of the convolution (default: 1).

    Returns:
    - nn.Conv2d: The configured 1x1 convolutional layer.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

# module to define a BasicBlock residual block for the resnet model
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, as used in ResNet.

    The input passes through conv3x3 -> norm -> ReLU -> conv3x3 -> norm, the
    (optionally downsampled) input is added back, and a final ReLU is applied.
    """

    # Output channels = planes * expansion.
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Set up the layers of the block.

        Args:
        - inplanes (int): Number of input channels.
        - planes (int): Number of output channels.
        - stride (int): Stride of the first convolution (default: 1).
        - downsample (nn.Module, optional): Module applied to the input before
          it is added to the residual branch (default: None).
        - groups (int): Must be 1 for this block (default: 1).
        - base_width (int): Must be 64 for this block (default: 64).
        - dilation (int): Must not exceed 1 for this block (default: 1).
        - norm_layer (Callable[..., nn.Module], optional): Normalization layer
          factory (default: nn.BatchNorm2d).

        Raises:
        - ValueError: If groups != 1 or base_width != 64.
        - NotImplementedError: If dilation > 1.
        """
        super().__init__()

        # Reject configurations this block cannot represent.
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # When stride != 1, conv1 and the downsample module both shrink the input.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the residual block.

        Args:
        - x (Tensor): Input tensor.

        Returns:
        - Tensor: Output of the block.
        """
        # Residual branch: conv -> norm -> ReLU -> conv -> norm.
        residual = self.bn1(self.conv1(x))
        residual = self.relu(residual)
        residual = self.bn2(self.conv2(residual))

        # Shortcut branch: the input itself, or its downsampled version.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


# module to define a Bottleneck residual block for the resnet model
class Bottleneck(nn.Module):
    """Three-convolution (1x1 -> 3x3 -> 1x1) residual block used in ResNet."""

    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution (self.conv2)
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1)
    # according to "Deep residual learning for image recognition" (https://arxiv.org/abs/1512.03385).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    # Output channels = planes * expansion.
    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        """
        Set up the layers of the block.

        Args:
        - inplanes (int): Number of input channels.
        - planes (int): Base channel count; the block outputs planes * expansion channels.
        - stride (int): Stride of the 3x3 convolution (default: 1).
        - downsample (nn.Module, optional): Module applied to the input before
          it is added to the residual branch (default: None).
        - groups (int): Number of groups for the 3x3 convolution (default: 1).
        - base_width (int): Base width used to size the inner layers (default: 64).
        - dilation (int): Dilation rate of the 3x3 convolution (default: 1).
        - norm_layer (Callable[..., nn.Module], optional): Normalization layer
          factory (default: nn.BatchNorm2d).
        """
        super().__init__()

        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # Channel count of the two inner layers.
        width = int(planes * (base_width / 64.0)) * groups
        out_channels = planes * self.expansion

        # When stride != 1, conv2 and the downsample module both shrink the input.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = make_norm(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = make_norm(width)
        self.conv3 = conv1x1(width, out_channels)
        self.bn3 = make_norm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the bottleneck block.

        Args:
        - x (Tensor): Input tensor.

        Returns:
        - Tensor: Output of the block.
        """
        # Residual branch: 1x1 -> 3x3 -> 1x1, with ReLU after the first two.
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.relu(self.bn2(self.conv2(residual)))
        residual = self.bn3(self.conv3(residual))

        # Shortcut branch: the input itself, or its downsampled version.
        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)

# module defining modified RESNET backbone

Now we modify the specific model architecture to include a metadata input branch.

In [None]:
class ResNet(nn.Module):
    """ResNet backbone extended with a fully connected branch for metadata.

    The image goes through the usual ResNet stem and four residual stages,
    is average-pooled and projected to ``img_size`` features. The metadata
    vector is projected to ``metadata_size`` features followed by a ReLU.
    Both feature vectors are concatenated, passed through dropout and mapped
    to ``num_classes`` outputs by a final linear layer.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 5,
        img_size: int = 64,
        metadata_size: int = 32,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        mode: Optional[str] = None,
        metadata_features: int = 3,
    ) -> None:
        """
        Build the network.

        Args:
            block: Residual block class (BasicBlock or Bottleneck).
            layers: Number of blocks in each of the four residual stages.
            num_classes: Number of output classes.
            img_size: Width of the image feature vector produced by ``fc_img``.
            metadata_size: Width of the metadata feature vector produced by ``fc_meta``.
            zero_init_residual: Zero-initialize the last norm layer of every
                residual block.
            groups: Number of groups passed to the residual blocks.
            width_per_group: Base width passed to the residual blocks.
            replace_stride_with_dilation: Three flags, one per stage 2-4,
                replacing that stage's stride with dilation.
            norm_layer: Normalization layer factory (default: nn.BatchNorm2d).
            mode: Stored as ``self.mode`` and not otherwise used by this class.
            metadata_features: Number of raw metadata values per sample.

        Raises:
            ValueError: If ``replace_stride_with_dilation`` does not have
                exactly three elements.
        """
        super(ResNet, self).__init__()

        # If norm_layer is not provided, default to nn.BatchNorm2d
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1

        # `mode` is now an explicit optional argument; it used to be read from
        # an undefined name, which raised NameError on construction.
        self.mode = mode
        self.img_size = img_size
        self.metadata_size = metadata_size

        # Check if replace_stride_with_dilation is provided
        if replace_stride_with_dilation is None:
            # If not provided, set it to a default value of [False, False, False]
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                f"or a 3-element tuple, got {replace_stride_with_dilation}"
            )

        self.groups = groups
        self.base_width = width_per_group

        # Initial convolutional layer
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual layers (layer1, layer2, layer3, layer4)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        # Adaptive average pooling and fully connected layers
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc_meta = nn.Linear(metadata_features, metadata_size)
        self.fc_img = nn.Linear(512 * block.expansion, img_size)

        # Classifier on the concatenated image + metadata features
        self.fc_combined = nn.Linear(metadata_size + img_size, num_classes)

        self.dropout = nn.Dropout(p=0.5)
        # Not applied in forward(): the network returns raw logits.
        self.soft = nn.Softmax(dim=1)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        planes: int,
        blocks: int,
        stride: int = 1,
        dilate: bool = False,
    ) -> nn.Sequential:
        """
        Build one residual stage of ``blocks`` blocks.

        Args:
            block: Residual block class to instantiate.
            planes: Base channel count of the stage.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block.
            dilate: If True, use dilation instead of the stride.

        Returns:
            nn.Sequential: The blocks of the stage, in order.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation

        # Adjust dilation and stride if dilate is True
        if dilate:
            self.dilation *= stride
            stride = 1

        # Create downsample layer if stride != 1 or number of input channels is different from output channels
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        # The first block may change resolution/channels and carries the downsample module.
        layers = [
            block(
                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
            )
        ]
        self.inplanes = planes * block.expansion

        # The remaining blocks keep the shape produced by the first one.
        layers.extend(
            block(
                self.inplanes,
                planes,
                groups=self.groups,
                base_width=self.base_width,
                dilation=self.dilation,
                norm_layer=norm_layer,
            )
            for _ in range(1, blocks)
        )

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor, metadata: Tensor, spectral_data: Optional[Tensor] = None) -> Tensor:
        """Compute class logits from an image batch and its metadata batch."""
        # Metadata branch
        metadata = torch.relu(self.fc_meta(metadata))

        # Image branch: stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)  # First residual layer
        x = self.layer2(x)  # Second residual layer
        x = self.layer3(x)  # Third residual layer
        x = self.layer4(x)  # Fourth residual layer

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc_img(x)

        # Fuse both branches and classify
        combined_features = torch.cat((x, metadata), dim=1)
        combined_features = self.dropout(combined_features)
        return self.fc_combined(combined_features)

    def forward(self, x: Tensor, metadata: Tensor, spectral_data: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass of the ResNet model.

        Args:
            x (Tensor): Input image tensor.
            metadata (Tensor): Metadata tensor.
            spectral_data (Tensor, optional): Accepted for backward
                compatibility; not used by this model.

        Returns:
            Tensor: Unnormalized class scores (logits).
        """
        return self._forward_impl(x, metadata, spectral_data)


# function to train the revised Resnet model
def resnet_new(block: Type[Union[BasicBlock, Bottleneck]],
               layers: List[int],
               pretrained: bool = False,
               num_classes: int = 5,
               metadata_size: int = 32,
               img_size: int = 64,
               progress: bool = True,
               **kwargs: Any) -> ResNet:
    """
    Create the multimodal (image + metadata) ResNet model.

    Args:
        block (Type[Union[BasicBlock, Bottleneck]]): Type of the residual block (BasicBlock or Bottleneck).
        layers (List[int]): List specifying the number of blocks in each layer of the network.
        pretrained (bool): Whether to initialize the backbone from torchvision's
            ImageNet ResNet18 weights. Only tensors whose name and shape match are
            copied, so this is meaningful for ``BasicBlock`` with ``[2, 2, 2, 2]``.
            Default is False.
        num_classes (int): Number of output classes. Default is 5.
        metadata_size (int): Width of the metadata feature vector. Default is 32.
        img_size (int): Width of the image feature vector. Default is 64.
        progress (bool): Whether to display a progress bar when downloading pretrained weights. Default is True.
        **kwargs (Any): Additional keyword arguments to pass to the ResNet constructor.

    Returns:
        ResNet: ResNet model.
    """
    # num_classes is forwarded explicitly; it used to be dropped, so the model
    # always fell back to the constructor default.
    model = ResNet(
        block=block,
        layers=layers,
        num_classes=num_classes,
        metadata_size=metadata_size,
        img_size=img_size,
        **kwargs,
    )

    if pretrained:
        # Imported here because only this branch needs torchvision's model zoo.
        from torchvision.models import ResNet18_Weights, resnet18

        donor_state = resnet18(weights=ResNet18_Weights.DEFAULT, progress=progress).state_dict()
        own_state = model.state_dict()
        # Copy only tensors that exist under the same name with the same shape;
        # the metadata/fusion layers have no counterpart and keep their
        # fresh initialization.
        compatible = {
            name: tensor
            for name, tensor in donor_state.items()
            if name in own_state and own_state[name].shape == tensor.shape
        }
        model.load_state_dict(compatible, strict=False)

    return model

## Create and Load Pretrained Model
Before we can train the model, we first have to define the model architecture and load some pretrained weights.  We use the PyTorch resnet18 model and its corresponding ImageNet pretrained weights.  Using pretrained weights reduces training time and can also improve final model performance when training data is limited.

In [54]:
from torchvision.models import resnet18, ResNet18_Weights, get_weight
from tqdm import tqdm
import torch 

# Train on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Image-only resnet18 with a 5-class output layer.
model = resnet18(num_classes = 5)
loaded_state_dict = torch.hub.load_state_dict_from_url("https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth")
current_model_dict = model.state_dict()
# Match checkpoint tensors to model tensors by NAME. Pairing the two dicts by
# position can load wrong tensors when they differ in key order or count
# (e.g. extra buffers in the model). Entries whose shape differs are skipped
# and keep the model's own initialization.
new_state_dict = {
    k: v
    for k, v in loaded_state_dict.items()
    if k in current_model_dict and v.size() == current_model_dict[k].size()
}
# strict=False: keys absent from new_state_dict are left untouched.
model.load_state_dict(new_state_dict, strict = False)
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Training Hyperparameters
Before we start training, there are a few training parameters we need to define.  First is the number of training epochs, which is how many times we use the entire training set.  One epoch of training passes once we've gone through the training set once.  In each epoch, a training step is called an iteration, where N images are loaded at a time based on batch size.  
Another important hyperparameter is the learning rate and its schedule.  The learning rate determines how large a step the optimizer takes at each update, and therefore how fast the model learns.  A learning rate that is too large or too small can drastically affect final model performance.  Typically, as training progresses, the learning rate is decreased by a learning rate scheduler.  In this case, at predefined epoch points, the learning rate is multiplied by gamma (<1).
Finally, we have to define what loss function we use for training.  For simple classification, we choose cross entropy loss.

In [55]:
import torch.nn as nn
# Number of full passes over the training set.
num_epochs = 50

# Adam over all model parameters, starting at a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Multiply the learning rate by 0.2 after every 15 calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

# Cross entropy loss for classification.
criterion = nn.CrossEntropyLoss()

## Training the model
Now we train the model and also evaluate on the validation set at each epoch.  If validation accuracy at current epoch is better than previous epochs, the current model weights are saved as best_model.pth.  Finally, at the end of training, the final model weights are also saved as final_model.pth.  

In [60]:
import copy

best_acc = 0

# Number of classes tracked by the per-class accuracy counters.
NUM_CLASSES = 5


def run_epoch(loader, phase):
    """Run one pass over ``loader`` and return its statistics.

    Args:
        loader: DataLoader yielding ``((images, metadata), image_names, labels)``.
        phase: ``"train"`` to update the weights; any other value evaluates only.

    Returns:
        Tuple ``(mean loss, overall accuracy in %, per-class accuracy in %)``.
    """
    training = phase == "train"
    model.train(training)  # model.train(False) is the same as model.eval()

    running_loss = 0.0
    running_corrects = 0
    running_corr = [0.0] * NUM_CLASSES
    running_total = [0.0] * NUM_CLASSES

    # The dataset returns the image together with its metadata, so the pair
    # has to be unpacked before anything is moved to the device.
    # NOTE: `model` as created earlier in this file is the image-only
    # resnet18, so only the images are fed to it. A model that takes
    # metadata would be called as model(images, metadata.to(device)).
    for (images, metadata), imagename, labels in tqdm(loader):
        images = images.to(device)
        labels = labels.to(device)

        if training:
            # zero the parameter gradients
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            output = model(images)
            # The loss is computed in both phases so the validation loss
            # reflects the validation batch, not the last training batch.
            loss = criterion(output, labels)
            if training:
                loss.backward()
                optimizer.step()

        # statistics
        _, preds = torch.max(output, 1)
        running_loss += loss.item() * images.size(0)
        running_corrects += torch.sum(preds == labels).item()

        # One host transfer per batch instead of one per sample.
        for pred, label in zip(preds.tolist(), labels.tolist()):
            running_total[label] += 1.0
            if pred == label:
                running_corr[label] += 1.0

    num_samples = max(1, len(loader.dataset))
    mean_loss = running_loss / num_samples
    accuracy = 100.0 * running_corrects / num_samples
    per_class = [100.0 * i / max(1, j) for i, j in zip(running_corr, running_total)]
    return mean_loss, accuracy, per_class


for epoch in range(num_epochs):
    # Training phase
    epoch_loss, epoch_acc, running_res = run_epoch(train_loader, "train")
    # Step the schedule exactly once per epoch, after the optimizer updates.
    scheduler.step()
    print(running_res)
    print("{} Loss: {:.4f} Average Accuracy: {:.4f}".format("train", epoch_loss, epoch_acc))

    # Validation phase
    epoch_loss, epoch_acc, running_res = run_epoch(val_loader, "validation")
    print(running_res)
    print("{} Loss: {:.4f} Average Accuracy: {:.4f}".format("validation", epoch_loss, epoch_acc))
    if(epoch_acc > best_acc):
        print("saving best model")
        best_acc = epoch_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        res = running_res.copy()
        torch.save(model.state_dict(), 'best_model.pth')

torch.save(model.state_dict(), 'final_model.pth')


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.92it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.00it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000
saving best model)


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.13it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.61it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.13it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.61it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.13it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.39it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.00it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.05it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.13it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.00it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.89it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.29it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.49it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.71it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.29it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 20.00it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.93it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.24it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.29it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.71it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.86it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.29it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.61it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.71it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.49it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.89it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.51it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.51it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.51it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.51it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.35it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.54it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.70it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.33it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.99it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.66it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.95it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.99it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.66it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.54it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.50it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 17.24it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.05it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 11.36it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.50it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 16.95it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.12it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  7.14it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.93it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.61it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.09it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.49it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.18it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.29it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.49it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.99it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.49it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 15.87it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 14.08it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 18.52it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 13.70it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
train Loss: 0.0000 Average Accuracy: 100.0000


100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 19.23it/s]


[0.0, 100.0, 0.0, 0.0, 0.0]
validation Loss: 0.0000 Average Accuracy: 100.0000
