In [2]:
import torch
import torchvision
import torch.nn.functional as F
import torch.optim as optim

from torch import nn
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm


In [5]:
class Dataset(object):
    def __getitem__(self):
        raise NotImplementedError
    
    def __len__(self):
        raise NotImplementedError

In [6]:
train_data_path = "./training_set/"
test_data_path = "./test_set/"

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_data = torchvision.datasets.ImageFolder(train_data_path, transform=transform)
test_data = torchvision.datasets.ImageFolder(test_data_path, transform=transform)

# val_data_path = "./val/"
# val_data = torchvision.datasets.ImageFolder(val_data_path, transform=transforms)

In [7]:
batch_size = 128
train_data_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_data_loader = data.DataLoader(test_data, batch_size=batch_size)

In [10]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(49152, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50, 2)

    def forward(self, x):
        x = x.view(-1, 49152)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

simplenet = SimpleNet()

In [11]:
optimizer = optim.Adam(simplenet.parameters(), lr=0.001)

def train(model, optimizer, train_loader, epochs, device):
    model.to(device)
    for epoch in tqdm(range(epochs)):
        training_loss = 0.0
        model.train()

        for batch in train_loader:
            optimizer.zero_grad()
            inputs, target = batch
            inputs= inputs.to(device)
            target = target.to(device)
            outputs = model(inputs)
            loss = nn.CrossEntropyLoss()(outputs, target)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()

        training_loss /= len(train_loader)

        print("Epoch: {}, Training Loss: {: 2f}".format(epoch, training_loss))

In [190]:
epochs = 15
train(simplenet, optimizer, train_data_loader, epochs, device=torch.device("cuda:0"))

torch.save(simplenet.state_dict(), "./simplenet.pth")


# simplenet_dict = torch.load("./simplenet.pth")
# simplenet.load_state_dict(simplenet_dict)

  7%|▋         | 1/15 [00:19<04:27, 19.09s/it]

Epoch: 0, Training Loss:  1.016718


 13%|█▎        | 2/15 [00:37<04:06, 18.98s/it]

Epoch: 1, Training Loss:  0.710606


 20%|██        | 3/15 [00:56<03:47, 18.94s/it]

Epoch: 2, Training Loss:  0.568944


 27%|██▋       | 4/15 [01:16<03:29, 19.02s/it]

Epoch: 3, Training Loss:  0.518133


 33%|███▎      | 5/15 [01:35<03:10, 19.01s/it]

Epoch: 4, Training Loss:  0.455660


 40%|████      | 6/15 [01:53<02:50, 19.00s/it]

Epoch: 5, Training Loss:  0.408730


 47%|████▋     | 7/15 [02:12<02:31, 19.00s/it]

Epoch: 6, Training Loss:  0.337036


 53%|█████▎    | 8/15 [02:31<02:12, 18.98s/it]

Epoch: 7, Training Loss:  0.289686


 60%|██████    | 9/15 [02:50<01:53, 18.99s/it]

Epoch: 8, Training Loss:  0.255157


 67%|██████▋   | 10/15 [03:10<01:35, 19.03s/it]

Epoch: 9, Training Loss:  0.246616


 73%|███████▎  | 11/15 [03:29<01:16, 19.03s/it]

Epoch: 10, Training Loss:  0.202283


 80%|████████  | 12/15 [03:48<00:57, 19.05s/it]

Epoch: 11, Training Loss:  0.193396


 87%|████████▋ | 13/15 [04:07<00:38, 19.03s/it]

Epoch: 12, Training Loss:  0.191515


 93%|█████████▎| 14/15 [04:25<00:18, 18.97s/it]

Epoch: 13, Training Loss:  0.133549


100%|██████████| 15/15 [04:44<00:00, 18.99s/it]

Epoch: 14, Training Loss:  0.120633





In [16]:
from PIL import Image

simplenet_dict = torch.load("./simplenet.pth")
simplenet.load_state_dict(simplenet_dict)

labels = ["cats", "dogs"]
img = Image.open("./cat.jpg")
# img = transforms(img)
# img = img.unsqueeze(0)
prediction = simplenet(img)
prediction = prediction.argmax()

print(labels[prediction])

<class 'PIL.JpegImagePlugin.JpegImageFile'>


AttributeError: view

In [15]:
corrects = 0
total = 0
avg_list = []
for batch in test_data_loader:
    inputs, targets = batch
    prediction = simplenet(inputs)
    correct = torch.eq(torch.max(F.softmax(prediction, dim=1), dim=1)[
                       1], targets).view(-1)
    correct_num = torch.sum(correct)
    total_instance = correct.size()[0]
    corrects += correct_num.item()
    total += total_instance
    avg_list.append(round(corrects/total*100, 2))
print("Model accuracy on test data: {}%".format(sum(avg_list)/len(avg_list)))


Model accuracy on test data: 67.409375%
