<a href="https://colab.research.google.com/github/MeganT2004/ENG1-T17-Website/blob/main/test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [10]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms

device = ( #preferred GPU usage
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

#hyperperameters
numEpochs = 50
learningRate = 0.01 #changes the model according to errors each time weights are updated


trainTransform = transforms.Compose([
      transforms.RandomRotation(30),
      transforms.RandomResizedCrop(120),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]) #Normalise dataset between -1 and 1

testTransform = transforms.Compose([
      transforms.RandomRotation(30),
      transforms.RandomResizedCrop(120),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])

# Create datasets for training and testing, downloaded if not already
trainSet = torchvision.datasets.Flowers102(root='./data', split = 'train', transform=trainTransform, download=True)
testSet = torchvision.datasets.Flowers102(root='./data', split= "test", transform=testTransform, download=True)


# Create data loaders for training and testing daasets
trainLoader = torch.utils.data.DataLoader(trainSet, batch_size=4, shuffle=True)
testLoader = torch.utils.data.DataLoader(testSet, batch_size=4, shuffle=False)

classes = ("pink primrose",
    "hard-leaved pocket orchid",
    "canterbury bells",
    "sweet pea",
    "english marigold",
    "tiger lily",
    "moon orchid",
    "bird of paradise",
    "monkshood",
    "globe thistle",
    "snapdragon",
    "colt's foot",
    "king protea",
    "spear thistle",
    "yellow iris",
    "globe-flower",
    "purple coneflower",
    "peruvian lily",
    "balloon flower",
    "giant white arum lily",
    "fire lily",
    "pincushion flower",
    "fritillary",
    "red ginger",
    "grape hyacinth",
    "corn poppy",
    "prince of wales feathers",
    "stemless gentian",
    "artichoke",
    "sweet william",
    "carnation",
    "garden phlox",
    "love in the mist",
    "mexican aster",
    "alpine sea holly",
    "ruby-lipped cattleya",
    "cape flower",
    "great masterwort",
    "siam tulip",
    "lenten rose",
    "barbeton daisy",
    "daffodil",
    "sword lily",
    "poinsettia",
    "bolero deep blue",
    "wallflower",
    "marigold",
    "buttercup",
    "oxeye daisy",
    "common dandelion",
    "petunia",
    "wild pansy",
    "primula",
    "sunflower",
    "pelargonium",
    "bishop of llandaff",
    "gaura",
    "geranium",
    "orange dahlia",
    "pink-yellow dahlia?",
    "cautleya spicata",
    "japanese anemone",
    "black-eyed susan",
    "silverbush",
    "californian poppy",
    "osteospermum",
    "spring crocus",
    "bearded iris",
    "windflower",
    "tree poppy",
    "gazania",
    "azalea",
    "water lily",
    "rose",
    "thorn apple",
    "morning glory",
    "passion flower",
    "lotus",
    "toad lily",
    "anthurium",
    "frangipani",
    "clematis",
    "hibiscus",
    "columbine",
    "desert-rose",
    "tree mallow",
    "magnolia",
    "cyclamen",
    "watercress",
    "canna lily",
    "hippeastrum",
    "bee balm",
    "ball moss",
    "foxglove",
    "bougainvillea",
    "camellia",
    "mallow",
    "mexican petunia",
    "bromelia",
    "blanket flower",
    "trumpet creeper",
    "blackberry lily")

TrainIter = iter(trainLoader) #Picking a random assortment of training images
images, labels = next(TrainIter)

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(16, 26, 5)
        self.fc1 = nn.Linear(3146, 1024) #breaking down the image until it is classified into one of the 102 categories
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 102)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
eval = 40
iteration=0

'''#TRAINING BELOW; REMOVE BLOCK COMMENTS TO RUN

for epoch in range(numEpochs):  # loop over the dataset multiple times

    runningLoss = 0.0
    for i, data in enumerate(trainLoader, 0):
        inputs, labels = data
        iteration+=1

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        runningLoss += loss.item()
        if iteration % eval == 0:
            print(f'[epoch: {epoch + 1}/{numEpochs}, step:{i + 1:5d}/{len(trainLoader)}] loss: {runningLoss / 40:.3f}')
            runningLoss = 0.0

print('Finished Training')

PATH = './TrainedNet.pth'
torch.save(net.state_dict(), PATH)

#TESTING STARTS BELOW'''

testIter = iter(testLoader)
images, labels = next(testIter)

net = Net()
net.load_state_dict(torch.load(PATH))
outputs = net(images)

correct = 0
total = 0


with torch.no_grad():
    for data in testLoader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1) #class with the highest prediction selected
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy: {100 * correct // total} %')

#preparing prediction calculation
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

with torch.no_grad():
    for data in testLoader:
        images, labels = data
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


     # print accuracy for each class
    for classname, correct_count in correct_pred.items():
      accuracy = 100 * float(correct_count) / total_pred[classname]
      print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

Accuracy: 14 %
Accuracy for class: pink primrose is 55.0 %
Accuracy for class: hard-leaved pocket orchid is 30.0 %
Accuracy for class: canterbury bells is 0.0 %
Accuracy for class: sweet pea is 0.0 %
Accuracy for class: english marigold is 0.0 %
Accuracy for class: tiger lily is 20.0 %
Accuracy for class: moon orchid is 30.0 %
Accuracy for class: bird of paradise is 47.7 %
Accuracy for class: monkshood is 65.4 %
Accuracy for class: globe thistle is 32.0 %
Accuracy for class: snapdragon is 0.0 %
Accuracy for class: colt's foot is 16.4 %
Accuracy for class: king protea is 13.8 %
Accuracy for class: spear thistle is 14.3 %
Accuracy for class: yellow iris is 0.0 %
Accuracy for class: globe-flower is 47.6 %
Accuracy for class: purple coneflower is 15.4 %
Accuracy for class: peruvian lily is 3.2 %
Accuracy for class: balloon flower is 0.0 %
Accuracy for class: giant white arum lily is 13.9 %
Accuracy for class: fire lily is 40.0 %
Accuracy for class: pincushion flower is 28.2 %
Accuracy for 