In [1]:
import torchvision
import torch
from torch.nn import functional as F
from torchvision import transforms
from torch.utils import data
import torch.optim as optim
from torch import nn
from PIL import Image

In [2]:
transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

In [3]:
train_data_path = "./train/"
train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transforms)

In [4]:
val_data_path = "./val/"
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=transforms)

In [5]:
# test_data_path = "./test/"
# test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=transforms)

In [6]:
batch_size = 64
train_data_loader = data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader  = data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
# test_data_loader  = data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [7]:
class SimpleNet(nn.Module):

    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(64*64*3, 84)
        self.fc2 = nn.Linear(84, 50)
        self.fc3 = nn.Linear(50,2)

    def forward(self, x):
        x = x.view(-1, 64*64*3)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.softmax(self.fc3(x))
        return x

model = SimpleNet()

In [8]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

model.to(device)

SimpleNet(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

In [9]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

1036628

In [10]:
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [11]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()
        training_loss /= len(train_loader)

        model.eval()
        num_correct = 0
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets)
            valid_loss += loss.data.item()
            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(
            epoch, training_loss,
            valid_loss, 
            num_correct / num_examples
        ))

In [12]:
train(model, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, val_data_loader, 20, device)

  x = F.softmax(self.fc3(x))
  correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)


Epoch: 0, Training Loss: 0.63, Validation Loss: 0.55, accuracy = 0.72
Epoch: 1, Training Loss: 0.55, Validation Loss: 0.53, accuracy = 0.74
Epoch: 2, Training Loss: 0.51, Validation Loss: 0.52, accuracy = 0.74
Epoch: 3, Training Loss: 0.47, Validation Loss: 0.51, accuracy = 0.74
Epoch: 4, Training Loss: 0.45, Validation Loss: 0.55, accuracy = 0.70
Epoch: 5, Training Loss: 0.45, Validation Loss: 0.51, accuracy = 0.76
Epoch: 6, Training Loss: 0.43, Validation Loss: 0.51, accuracy = 0.75
Epoch: 7, Training Loss: 0.42, Validation Loss: 0.53, accuracy = 0.73
Epoch: 8, Training Loss: 0.41, Validation Loss: 0.50, accuracy = 0.78
Epoch: 9, Training Loss: 0.40, Validation Loss: 0.50, accuracy = 0.77
Epoch: 10, Training Loss: 0.40, Validation Loss: 0.55, accuracy = 0.71
Epoch: 11, Training Loss: 0.41, Validation Loss: 0.50, accuracy = 0.77
Epoch: 12, Training Loss: 0.39, Validation Loss: 0.49, accuracy = 0.79
Epoch: 13, Training Loss: 0.39, Validation Loss: 0.53, accuracy = 0.73
Epoch: 14, Train

In [13]:
# accuracy = 0.79

In [14]:


# labels = ['cat','fish']

# img = Image.open(FILENAME)
# img = transforms(img)
# img = img.unsqueeze(0)

# prediction = simplenet(img)
# prediction = prediction.argmax()
# print(labels[prediction])

In [15]:
# torch.save(simplenet, "/tmp/simplenet")

In [16]:
# simplenet = torch.load("/tmp/simplenet")

In [17]:
# torch.save(model.state_dict(), PATH)

In [18]:
# simplenet = SimpleNet()
# simplenet_state_dict = torch.load("/tmp/simplenet")
# simplenet.load_state_dict(simplenet_state_dict)