# Классификация рыб и котов

In [91]:
import torchvision, torch
from torchvision import transforms

In [92]:
train_data_path = './Fish-vs-Cats/train/'

transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] )
])

train_data = torchvision.datasets.ImageFolder(root=train_data_path, transform=transforms)

In [93]:
len(train_data), train_data[0]

(800,
 (tensor([[[-0.6965, -0.6965, -0.6109,  ..., -0.2684, -0.3027, -0.4397],
           [-0.6965, -0.6794, -0.5938,  ..., -0.3027, -0.2856, -0.3883],
           [-0.6281, -0.5938, -0.5767,  ..., -0.2856, -0.3198, -0.3369],
           ...,
           [-0.3198, -0.3198, -0.2513,  ..., -1.1760, -0.8164, -0.4911],
           [-0.3712, -0.3027, -0.2856,  ..., -1.4329, -1.2617, -0.8335],
           [-0.3027, -0.2342, -0.2856,  ..., -1.4500, -1.4500, -1.2617]],
  
          [[-0.6702, -0.7052, -0.6352,  ..., -0.2500, -0.2500, -0.4076],
           [-0.6877, -0.6877, -0.6001,  ..., -0.2325, -0.2150, -0.3550],
           [-0.6176, -0.6001, -0.5826,  ..., -0.1975, -0.2150, -0.2675],
           ...,
           [-0.1099, -0.1099, -0.0224,  ..., -1.2654, -0.9678, -0.8277],
           [-0.1625, -0.0749, -0.0749,  ..., -1.4230, -1.3354, -1.0028],
           [-0.1450, -0.0574, -0.0924,  ..., -1.4230, -1.4580, -1.3354]],
  
          [[-0.3927, -0.3927, -0.3230,  ...,  0.0082, -0.0092, -0.2010],
     

In [94]:
val_data_path = './Fish-vs-Cats/val/'
val_data = torchvision.datasets.ImageFolder(root=val_data_path, transform=transforms)

In [95]:
len(val_data)

108

In [96]:
test_data_path = './Fish-vs-Cats/test/'
test_data = torchvision.datasets.ImageFolder(root=test_data_path, transform=transforms)

In [97]:
len(test_data)

160

In [98]:
batch_size = 64
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size)
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size)
test_data_loader = torch.utils.data.DataLoader(test_data, batch_size)

In [99]:
class simple_nn(torch.nn.Module):
    def __init__(self):
        super(simple_nn, self).__init__()
        self.fc1 = torch.nn.Linear(64*64*3, 84)
        self.fc2 = torch.nn.Linear(84, 50)
        self.fc3 = torch.nn.Linear(50, 2)

    def forward(self, x):
        x = x.view(-1, 64*64*3)
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        return x

In [100]:
net = simple_nn()

In [101]:
import torch.optim as optim
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [102]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device="cpu"):
    for epoch in range(1, epochs+1):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0 
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets) 
            valid_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(torch.nn.functional.softmax(output, dim=1), dim=1)[1], targets)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader.dataset)

        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,
        valid_loss, num_correct / num_examples))

In [104]:
if torch.cuda.is_available():
    device = torch.device("cuda") 
else:
    device = torch.device("cpu")

net.to(device)

train(net, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=20, device=device)

Epoch: 1, Training Loss: 1.68, Validation Loss: 2.10, accuracy = 0.75
Epoch: 2, Training Loss: 1.22, Validation Loss: 2.03, accuracy = 0.76
Epoch: 3, Training Loss: 0.79, Validation Loss: 2.33, accuracy = 0.66
Epoch: 4, Training Loss: 0.90, Validation Loss: 1.74, accuracy = 0.72
Epoch: 5, Training Loss: 0.63, Validation Loss: 1.82, accuracy = 0.71
Epoch: 6, Training Loss: 0.56, Validation Loss: 1.82, accuracy = 0.69
Epoch: 7, Training Loss: 0.55, Validation Loss: 1.74, accuracy = 0.71
Epoch: 8, Training Loss: 0.50, Validation Loss: 1.76, accuracy = 0.70
Epoch: 9, Training Loss: 0.48, Validation Loss: 1.68, accuracy = 0.72
Epoch: 10, Training Loss: 0.45, Validation Loss: 1.70, accuracy = 0.72
Epoch: 11, Training Loss: 0.42, Validation Loss: 1.68, accuracy = 0.72
Epoch: 12, Training Loss: 0.39, Validation Loss: 1.62, accuracy = 0.74
Epoch: 13, Training Loss: 0.37, Validation Loss: 1.63, accuracy = 0.75
Epoch: 14, Training Loss: 0.35, Validation Loss: 1.66, accuracy = 0.69
Epoch: 15, Trai

In [105]:
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES=True

In [108]:
labels = ['cat','fish']

img = Image.open("./Fish-vs-Cats/val/fish/100_1422.JPG") 
img = transforms(img).to(device)
img = torch.unsqueeze(img, 0)

net.eval()
prediction = torch.nn.functional.softmax(net(img), dim=1)
prediction = prediction.argmax()
print(labels[prediction]) 

fish


In [None]:
torch.save(net, "/tmp/simplenet") 
simplenet = torch.load("/tmp/simplenet")    

In [None]:
torch.save(simplenet.state_dict(), "/tmp/simplenet")    
simplenet = simple_nn()
simplenet_state_dict = torch.load("/tmp/simplenet")
simplenet.load_state_dict(simplenet_state_dict)   