In [1]:
import copy
import random
import time
import os
import re

import torch
import torch.nn as nn
import torch.nn.functional 
import torch.optim 
import torch.utils.data

import torchvision.transforms
import torchvision.datasets

import skimage.io
import skimage.transform
import sklearn.preprocessing

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [2]:
from PIL import Image, ImageEnhance

# Functions

In [3]:
def set_seeds(seed):
    """sets seeds for several used packages for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [4]:
def test_set_seeds():
    """
    Test if the set_seeds function works.
    """
    seed = 42
    set_seeds(seed) # Call the set_seeds function.
    # create random datasets using torch.randint, random.randint, and np.random.randint. 
    x = torch.randint(0, 10, (3, 3))
    y = random.randint(0,100) 
    z = np.random.randint(5, size=(2, 4))
    set_seeds(seed) # Set the same seeds again.
    # Check the random datasets are still the same.
    assert torch.equal(x, torch.randint(0, 10, (3, 3))), "The set_seed function is broken!"
    assert y == random.randint(0,100), "The set_seed function is broken!"
    assert np.array_equal(z, np.random.randint(5, size=(2, 4))), "The set_seed function is broken!"
    return

In [5]:
test_set_seeds()

In [6]:
def encode_column(column):
    """
    takes single columned Pandas DataFrame of categorical data and encodes it
    into array of class binarys
    """
    encoder = sklearn.preprocessing.OneHotEncoder()
    shape_arr = encoder.fit_transform(column).toarray().astype(int)
        
    return list(shape_arr)

In [7]:
def test_encode_column_1():
    """
    Test if the encode_column function can generate correct output.
    """
    labels = pd.read_csv('data/10x_labels_4.csv')  # Import the csv file containing labels.
    # Create two new columns, color and shape.
    new = labels["Description"].str.split(" ", n=1, expand=True)
    input_column_color = new[0].values
    input_column_shape = new[1].values
    # Call the encode_column function to turn the color and shape features into binary codes.
    output_color = encode_column(input_column_color.reshape(-1, 1))
    output_shape = encode_column(input_column_shape.reshape(-1, 1))
    # Expected output of the encode_column function.
    expect_output_color = [[0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 0, 0],
                           [0, 1, 0, 0, 0, 0],
                           [0, 0, 1, 0, 0, 0],
                           [0, 0, 1, 0, 0, 0],
                           [1, 0, 0, 0, 0, 0],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 0, 1, 0, 0],
                           [0, 0, 1, 0, 0, 0],
                           [0, 0, 1, 0, 0, 0],
                           [0, 0, 0, 0, 1, 0],
                           [0, 0, 0, 0, 1, 0]]
    expect_output_shape = [[1, 0, 0, 0, 0],
                           [0, 0, 0, 0, 1],
                           [0, 0, 0, 0, 1],
                           [0, 1, 0, 0, 0],
                           [0, 1, 0, 0, 0],
                           [0, 1, 0, 0, 0],
                           [1, 0, 0, 0, 0],
                           [0, 0, 0, 1, 0],
                           [0, 0, 0, 1, 0],
                           [0, 0, 0, 0, 1],
                           [0, 0, 1, 0, 0],
                           [0, 0, 1, 0, 0],
                           [0, 0, 0, 0, 1],
                           [1, 0, 0, 0, 0],
                           [0, 0, 0, 1, 0],
                           [0, 0, 1, 0, 0]]
    # Check if the expected output is the same as the actual output.
    assert np.array_equal(expect_output_color, output_color), "The function encode_column is broken!"
    assert np.array_equal(expect_output_shape, output_shape), "The function encode_column is broken!"
    return


In [8]:
test_encode_column_1()

In [9]:
def test_encode_column_2():
    """
    Test if the encode_column function is responsive to a wrong datatype of the input.
    """
    input_1 = 10
    test1 = False
    try:
        encode_column(input_1)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test1 = True
    assert test1 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype 'int'."
    
    labels = pd.read_csv('data/10x_labels_4.csv')
    new = labels["Description"].str.split(" ", n=1, expand=True)
    input_2 = new[0].values
    test2 = False
    try:
        encode_column(input_2)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test2 = True
    assert test2 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype '1D array'."
    
    input_3 = [1, 2, 3]
    test3 = False
    try:
        encode_column(input_3)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test3 = True
    assert test3 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype 'list'."
    
    input_4 = 'input'
    test4 = False
    try:
        encode_column(input_4)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test4 = True
    assert test4 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype 'str'."
    
    input_5 = True
    test5 = False
    try:
        encode_column(input_5)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test5 = True
    assert test5 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype 'bool'."
    
    input_6 = 1.4
    test6 = False
    try:
        encode_column(input_6)
    except Exception as e:
        assert isinstance(e, ValueError), "Wrong type of error."
        test6 = True
    assert test6 == True, "Test failed! The encode_column function is not responsive to the wrong input datatype 'float'."
    return

In [10]:
test_encode_column_2()

In [11]:
def prep_data(labels, image_root):
    """
    Takes in raw labels dataframe and converts it into the format
    expected for tenX_dataset class
    """

    #Splitting description column into color and shape columns
    new = labels["Description"].str.split(" ", n=1, expand=True)
    labels.drop(columns=['Description'], inplace=True)
    labels['Color'] = new[0].values
    labels['Shape'] = new[1].values
    
    #Decomposing sample keywords into seperate strings
    sample_names = labels["Sample"].str.split(" ", n=1, expand=False)
    labels['Sample'] = sample_names
    
    #Converting identification into boolean for is/is not plastic
    PLASTICS = ['polystyrene', 'polyethylene','polypropylene','Nylon','ink + plastic','PET','carbon fiber']
    identification = labels['Identification']
    
    for i in range(0,len(identification)):
        if identification[i] in PLASTICS:
            identification[i] = True
        else:
            identification[i] = False

    labels['Identification'] = identification
    labels.rename(columns={'Identification': 'isPlastic'}, inplace=True)
    labels['isPlastic'] = labels["isPlastic"].astype(int)
    
    
    #Encoding shape and color data
    labels['Shape'] = encode_column(labels[['Shape']])
    labels['Color'] = encode_column(labels[['Color']])
    
    labels = add_filenames(labels, image_root)
    
    return labels

In [13]:
def test_prep_data_1():
    """
    Test if the prep_data function can generate correct output.
    """
    # Load the csv file as the input data frame.
    input_df = pd.read_csv('data/10x_labels_5.csv')
    image_dir = 'data/images_10x'
    # Call the prep_data function and get the actual output "result".
    result = prep_data(input_df, image_dir)
    # Load the output data frame from a csv file.
    output_df = pd.read_csv('data/prep_data_output.csv')
    # Modify the format of the output data frame to make it the expected output.
    for i, rowi in output_df['Sample'].iteritems():
        output_df['Sample'].loc[i] = rowi.split(',')
    for j, rowj in output_df['Color'].iteritems():
        output_df['Color'].loc[j] = np.fromstring(rowj, dtype=int, sep=' ')
    for k, rowk in output_df['Shape'].iteritems():
        output_df['Shape'].loc[k] = np.fromstring(rowk, dtype=int, sep=' ')
    # Check if the expected output data frame is the same as the actual output data frame. 
    assert output_df.equals(result), "The prep_data function is broken!"
    return
    

In [14]:
test_prep_data_1()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_single_block(indexer, value, name)
  if sys.path[0] == '':
  


In [15]:
def test_prep_data_2():
    """
    Test if the prep_data function is responsive to a wrong datatype of the input.
    """
    image_dir = 'data/images_10x'
    input_1 = 10
    test1 = False
    try:
        prep_data(input_1, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test1 = True
    assert test1 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'int'."
    
    input_2 = 1.2
    test2 = False
    try:
        prep_data(input_2, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test2 = True
    assert test2 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'float'."
    
    input_3 = [1, 2, 3]
    test3 = False
    try:
        prep_data(input_3, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test3 = True
    assert test3 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'list'."
    
    input_4 = (1, 2, 3)
    test4 = False
    try:
        prep_data(input_4, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test4 = True
    assert test4 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'tuple'."
    
    input_5 = 'input'
    test5 = False
    try:
        prep_data(input_5, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test5 = True
    assert test5 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'str'."
    
    input_6 = False
    test6 = False
    try:
        prep_data(input_6, image_dir)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test6 = True
    assert test6 == True, "Test failed! The prep_data function is not responsive to the wrong input datatype 'bool'."
    return

In [16]:
test_prep_data_2()

In [17]:
def add_filenames(labels, image_root):
    """
    Replaces sample column of labels with the actual filename so that the dataset class doesn't have to do that work.
    """
    for i, item in labels['Sample'].iteritems():
        if type(item) != list: 
            raise TypeError("The type of each item in 'Sample' column has to be 'list'.") 
    image_filenames = os.listdir(image_root)
    labels.insert(loc=1, column='File', value=None)
    for index, row in labels.iterrows():
        sample = row['Sample']
        for fname in image_filenames:
            str_id = '^' + ' '.join(row['Sample']) + ' .*'
            result = re.search(str_id, fname)
            if result:
                image_file = result.group()
                assert(os.path.exists('./data/images_10x/' + image_file))
                break
        else:
            image_file = None
        labels.loc[index, 'File'] = image_file
    return labels

In [18]:
def test_add_filenames_1():
    """
    Test if the add_filenames function can generate correct output.
    """
    # Load the input data frame for the add_filenames function. 
    input_df = pd.read_csv('data/10x_labels_5.csv')
    image_root = 'data/images_10x'
    # Prepare the input data frame
    sample_names = input_df["Sample"].str.split(" ", n=1, expand=False)
    input_df['Sample'] = sample_names
    # Call the add_filenames function and get the actual output data frame.
    result = add_filenames(input_df, image_root)
    # Load the output data frame. 
    output_df = pd.read_csv('data/10x_labels_5_output.csv')
    # Modifiy the output data frame to make it expected output. 
    for i, rowi in output_df['Sample'].iteritems():
        output_df['Sample'].loc[i] = rowi.split(',')
    # Check if the expected output data frame is the same as the actual output data frame.
    assert output_df.equals(result), "The add_filenames function is broken!"
    return

In [20]:
def test_add_filenames_2():
    """
    Test if the add_filenames function is responsive to a wrong datatype of the input.
    """
    input_labels = pd.read_csv('data/10x_labels_5.csv')
    image_root = 'data/images_10x'
    test1 = False
    try: 
        add_filenames(input_labels, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), 'Wrong type of error'
        test1 = True
    assert test1 == True, "Test failed! The add_filenames function is not resposive to TypeErorr of each item in the 'Sample' column"
    
    input_1 = 10
    test2 = False
    try:
        add_filenames(input_1, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test2 = True
    assert test2 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'int'."
    
    input_2 = 1.2
    test3 = False
    try:
        add_filenames(input_2, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test3 = True
    assert test3 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'float'."
    
    input_3 = [1, 2, 3]
    test4 = False
    try:
        add_filenames(input_3, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test4 = True
    assert test4 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'list'."
    
    input_4 = (1, 2, 3)
    test5 = False
    try:
        add_filenames(input_4, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test5 = True
    assert test5 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'tuple'."
    
    input_5 = 'input'
    test6 = False
    try:
        add_filenames(input_5, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test6 = True
    assert test6 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'str'."
    
    input_6 = False
    test7 = False
    try:
        add_filenames(input_6, image_root)
    except Exception as e:
        assert isinstance(e, TypeError), "Wrong type of error."
        test7 = True
    assert test7 == True, "Test failed! The add_filenames function is not responsive to the wrong input datatype 'bool'."
    return

In [21]:
test_add_filenames_1()

In [22]:
test_add_filenames_2()

# Custom Dataset

In [23]:
class tenX_dataset(torch.utils.data.Dataset):
    """
    Class inherited from torch Dataset. Required methods are, init,
    len, and getitem.
    """
    def __init__(self, labels_frame, image_dir, transform):
        """
        initializes an instance of the class. Here we store 4 variables
        in the class. Calling init just looks like dataset = tenX_dataset(lables, 'image_folder', transform).
        
        labels: altered version of csv file
        image_dir: The file path to the folder the images are in
        image_filenames: A list of all the image file names in the image folder
        transform: A pytorch object. Works like a function. You call transform(x) and it performs
                    a series of operations on x
        """
        self.labels = labels_frame
        self.image_dir = image_dir
        self.image_filenames = os.listdir(self.image_dir)
        self.transform = transform
        

    def __len__(self):
        """Returns the length of the dataset"""
        return len(self.labels)
    
    
    def __getitem__(self, idx):
        """
        Returns a dictionary containing image and image data. Right now
        it looks like: 
        sample = {'image': image, 'plastic': [0], 'shape':[0,0,0,0,0], 'color':[0,0,0,0,0]}
        """
        image_filename = self.labels['File'][idx]
        image = None
             
        if image_filename is not None:
            image_filepath = os.path.join(self.image_dir, image_filename)
            image = skimage.io.imread(image_filepath)
            if self.transform is not None:
                image = self.transform(image)

        sample = {'image': image,
                  'shape': self.labels['Shape'][idx],
                  'color': self.labels['Color'][idx],
                  'plastic': self.labels['isPlastic'][idx]}
  
        return sample

In [24]:
def test_tenX_dataset():
    """
    Test if the class tenX_dataset can generate correct output.
     
    """
    # Load the inputs for tenX_dataset.
    image_dir = 'data/images_10x'
    labels = prep_data(pd.read_csv('data/10x_labels_5.csv'), image_dir)
    # To make the test more simple, define the "transform" as a function that returns 
    # the input directly without doing anything. 
    def transforms(image):
        return image    
    # Create an object tenX.
    tenX = tenX_dataset(labels, image_dir, transforms)
    # check if the class tenX_dataset can generate the correct output by comparing the actual output with the expected output.
    assert len(tenX) == 10, "The len method in class tenX_dataset is broken!"
    assert tenX[1]['image'].size == 1082880, "The getitem method in class tenX_dataset is broken!"
    assert np.array_equal(tenX[1]['shape'], np.array([0, 0, 0, 1])), "The getitem method in class tenX_dataset is broken!"
    assert np.array_equal(tenX[1]['color'], np.array([0, 0, 1, 0])), "The getitem method in class tenX_dataset is broken!"
    assert tenX[1]['plastic'] == 0, "The getitem method in class tenX_dataset is broken!"
    return

In [25]:
test_tenX_dataset()

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy


### Plotting first 20 images of dataset. Obviously getting quite a few duplicates

In [None]:
labels_filepath = 'data/10x_labels_5.csv'
image_dir = 'data/images_10x'
pd_df = pd.read_csv(labels_filepath)
labels = prep_data(pd_df, image_dir)
#Probably wont center crop since all objects are always near middle of this. This will speed up the network
#May want to consider translating images since elsewise our nn is biased to the center
#Also option to change contrast with ImageEnhance https://pythonexamples.org/python-pillow-adjust-image-contrast/
#two ways to normalize. Batch - noramlize wrt dataset mean/std, Indvid - noramlize each image with own mean and std
#

transforms = torchvision.transforms.Compose([
                            torchvision.transforms.ToPILImage(),
                            torchvision.transforms.CenterCrop((300, 350)),
                            torchvision.transforms.RandomRotation((-180,180)),
                            torchvision.transforms.ToTensor()
                                      ])
tenX = tenX_dataset(labels, image_dir, transforms)

for i in range(len(tenX)):
    sample = tenX[i]['image']
    plt.figure(i)
    if sample is not None:
        print(np.shape(sample))
        plt.imshow(sample.T)
    if i>100:
        break

# Things to improve/fix
* Make sure the nonetypes are because the file actually isn't in my folder of images
* Code for normalizing image data
* Image augmentation. Probably want to cut off some of the edges to get rid of number stuff and decrease extraneous information. The think we actually care about is only occupying like 5-10% of the image.

# Start of me trying to plug into cnn

Most of the code came from this tutorial: https://github.com/bentrevett/pytorch-image-classification/blob/master/2_lenet.ipynb

I was just trying to get this to work so I won't understand it as much

In [None]:
image_dir = 'data/images_10x'
labels_frame = labels

#This transform just resizes the images to 3,480,752. So 3 for red green blue then height of 480
#and width of 752. 
transform = torchvision.transforms.Compose([
                            torchvision.transforms.ToPILImage(),
                            torchvision.transforms.Resize((480, 752)),
                            torchvision.transforms.ToTensor()
                                      ])


train_data = tenX_dataset(labels_frame, image_dir, transform = transform)

#### Splitting into train/validation set

In [None]:
VALID_RATIO = 0.9

n_train_examples = int(len(train_data) * VALID_RATIO)
n_valid_examples = len(train_data) - n_train_examples

train_data, valid_data = torch.utils.data.random_split(train_data, 
                                           [n_train_examples, n_valid_examples])

In [None]:
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')

#### Declaring iterator. The thing that will loop through our dataset.

In [None]:
BATCH_SIZE = 5

train_iterator = torch.utils.data.DataLoader(train_data, 
                                 shuffle = True, 
                                 batch_size = BATCH_SIZE)

valid_iterator = torch.utils.data.DataLoader(valid_data, 
                                 batch_size = BATCH_SIZE)

#### The CNN archetecture

In [None]:
class LeNet(nn.Module):
    def __init__(self, output_dim):
        """
        Initializes CNN. Here we just define layer shapes that we call in the forward func
        """
        super().__init__()

        #Convulution layer 1. 
        #3 input channels (for three images Red, Green, Blue)
        #6 output channels (I THINK this means we are applying two different filters to each image
        #3 images, two filters each, we end up with 6 'images')
        #kernel size is I THINK telling the filters took filter each set of 5 pixels into one.
        #So are images will shrink a little as the edges get cut off
        self.conv1 = nn.Conv2d(in_channels = 3, 
                               out_channels = 6, 
                               kernel_size = 5)
        
        #Convultion layer 2. See above
        self.conv2 = nn.Conv2d(in_channels = 6, 
                               out_channels = 12, 
                               kernel_size = 5)
        
        #Linear layers. These probably arent complicated but I don't follow haha
        #I think it turning the 259740 pixel values into 6 values. Then the second layers
        #Turns the 6 into a different 6? and then 6 into 2. I'm not sure why 2 and not 1.
        #Seeing as the output should be a number between 0-1. Closer to 0 = not plastic,
        #closer to 1 = plastic. But I got errors about not having enough classes when
        #I only had 1 output neuron.
        #TBH these linear layers I just changed based on the error messages I got.
        self.fc_1 = nn.Linear(259740, 6)
        self.fc_2 = nn.Linear(6, 6)
        self.fc_3 = nn.Linear(6, 2)

    def forward(self, x):
        """
        Function that performs all the neural network forward calculation i.e.
        takes image data from the input of the neural network to the output
        """

        
        x = self.conv1(x)
    
        #Gonna have to look at tutorial link.
        x = nn.functional.max_pool2d(x, kernel_size = 2)
        
        x = nn.functional.relu(x)
        
        x = self.conv2(x)
                
        x = nn.functional.max_pool2d(x, kernel_size = 2)
        
        x = nn.functional.relu(x)
        
        x = x.view(x.shape[0], -1)
                
        h = x
        
        x = self.fc_1(x)
                
        x = nn.functional.relu(x)

        x = self.fc_2(x)
                
        x = nn.functional.relu(x)

        x = self.fc_3(x)
        
        return x, h

In [None]:
#Instancing model, loss criteria, device to perform calculations on, and optimizer.
OUTPUT_DIM = 1
model = LeNet(OUTPUT_DIM)


criterion = nn.CrossEntropyLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters())

In [None]:
#Telling the model and loss function to do math on whatever device is
model = model.to(device)
criterion = criterion.to(device)

In [None]:
def calculate_accuracy(y_pred, y):
    """
    Function calculate accuracy. See tutorial, may not
    even be accurate for our model but it at least runs
    """
    top_pred = y_pred.argmax(1, keepdim = True)
    correct = top_pred.eq(y.view_as(top_pred)).sum()
    acc = correct.float() / y.shape[0]
    return acc

In [None]:
def train(model, iterator, optimizer, criterion, device):
    """
    Training loop. Takes data through NN calculates loss and adjusts NN. Repeat
    """
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    #Need to add logic to skip iteration if image is None
    for sample in iterator:  
        print('training')
        if sample['image'] is None:
            print('got a None')
            continue
        image = sample['image'].to(device)
        isPlastic = sample['plastic'].to(device)
                
        optimizer.zero_grad()      
        y_pred, what = model(image)

        loss = criterion(y_pred, isPlastic)
        acc = calculate_accuracy(y_pred, isPlastic)
        loss.backward()    
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [None]:
#Here the model is actually trained
EPOCHS = 20

best_valid_loss = float('inf')

for epoch in range(EPOCHS):
    
    start_time = time.monotonic()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, device)

    
    end_time = time.monotonic()

    #epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    #print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    #print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')