In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms

In [2]:
# ResNet-34 with ImageNet-pretrained weights, fetched through torch.hub
# (a cached copy is reused when present).
# NOTE(review): 'pytorch/vision' without a ':tag' tracks the repo's default
# branch, so the loaded code can change between runs; consider pinning a tag.
modelo_pre_treinado = torch.hub.load(
    'pytorch/vision',
    'resnet34',
    pretrained=True,
)

Using cache found in /home/breno-cavalcanti/.cache/torch/hub/pytorch_vision_master


In [3]:
# Freeze the pretrained backbone but keep every BatchNorm layer trainable.
# Selecting by module type (rather than '"bn" in name') also catches the
# BatchNorm inside torchvision ResNet shortcut branches, whose parameters are
# named 'layerN.0.downsample.1.*' and were silently left frozen by the name test.
_BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
for module in modelo_pre_treinado.modules():
    trainable = isinstance(module, _BATCHNORM_TYPES)
    # recurse=False: only this module's own parameters; children are visited
    # separately by modules(), so every parameter is set exactly once.
    for param in module.parameters(recurse=False):
        param.requires_grad = trainable

In [4]:
# Display the module tree to inspect the layer names the freezing cell matches
# on (e.g. `bn1`, `bn2`) and to locate the `fc` head that is replaced below.
modelo_pre_treinado

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [5]:
num_classes = 2  # NOTE(review): must equal the number of class folders under ./train

# Replacement classification head: in_features -> 512 -> num_classes, with ReLU
# and Dropout (p=0.5, the nn.Dropout default) between the two linear layers.
# Built after the freezing cell, so its parameters keep requires_grad=True.
modelo_pre_treinado.fc = nn.Sequential(
    nn.Linear(modelo_pre_treinado.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(512, num_classes),
)

In [6]:
def train(model, optimizer, loss_fn, train_loader, epochs=3, device="cpu"):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Args:
        model: network to optimise; switched to training mode at the start of
            every epoch. It is expected to already live on ``device``.
        optimizer: optimizer bound to the parameters that should be updated.
        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar loss.
            It is assumed to be a batch mean (e.g. ``nn.CrossEntropyLoss()``
            with the default ``reduction="mean"``) — TODO confirm if changed.
        train_loader: DataLoader yielding ``(inputs, targets)`` pairs; its
            ``dataset`` must support ``len()``.
        epochs: number of full passes over the data.
        device: device every batch is moved to before the forward pass.

    Returns:
        List with the per-sample mean training loss of each epoch (the same
        values that are printed).
    """
    history = []
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            # The loss is a batch mean; weight it by the batch size so the
            # epoch average is per-sample even when the last batch is short.
            running_loss += loss.item() * inputs.size(0)
        training_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch: {epoch}, Training Loss: {training_loss}')
        history.append(training_loss)
    return history

In [7]:
batch_size = 32
img_dimensions = 224

# Preprocessing applied to every training image: resize straight to a square
# (non-square images are stretched, not cropped), convert to a tensor, then
# normalise each channel with the ImageNet mean/std used by the pretrained weights.
img_transforms = transforms.Compose(
    [
        transforms.Resize((img_dimensions, img_dimensions)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
train_data_path = "./train"

# ImageFolder expects one sub-folder per class under the root; class indices
# are assigned from the sorted folder names.
train_data = torchvision.datasets.ImageFolder(
    root=train_data_path,
    transform=img_transforms,
)

In [9]:
num_workers = 6

train_data_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,              # reshuffled at every epoch
    num_workers=num_workers,   # worker processes that load/transform batches
)

In [5]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [11]:
modelo_pre_treinado.to(device)

# Give the optimizer only the parameters still marked requires_grad=True;
# frozen weights never receive gradients, so tracking them is pure overhead.
parametros_treinaveis = [p for p in modelo_pre_treinado.parameters() if p.requires_grad]
optimizer = optim.Adam(parametros_treinaveis, lr=0.001)
loss_fn = nn.CrossEntropyLoss()

# Assigned so the cell does not echo the return value of train().
historico_treino = train(
    model=modelo_pre_treinado,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_data_loader,
    device=device,
)

Epoch: 0, Training Loss: 0.12528927942292734
Epoch: 1, Training Loss: 0.07423590209018176
Epoch: 2, Training Loss: 0.061467851464763204


In [12]:
import os
def find_classes(dir):
    """List the class sub-folders of ``dir`` and map each to an integer index.

    Only sub-directories count and they are sorted by name, which is how
    ``torchvision.datasets.ImageFolder`` assigns class indices. The returned
    labels therefore line up with the training targets. Stray files such as
    ``.DS_Store`` are ignored.

    Args:
        dir: dataset root containing one sub-folder per class. The name shadows
            the builtin but is kept for keyword callers.

    Returns:
        ``(classes, class_to_idx)``: sorted class names and a name -> index dict.
        Both are also printed.
    """
    with os.scandir(dir) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    print(classes, class_to_idx)
    return classes, class_to_idx

def make_prediction(model, filename, transform=None):
    """Classify a single image file with ``model`` and print the result.

    Args:
        model: trained classifier producing one logit per class, in the index
            order returned by ``find_classes('./train')``.
        filename: path to the image to classify.
        transform: preprocessing for the loaded image. Defaults to the
            notebook's ``img_transforms`` so inference matches training.

    Returns:
        The predicted class name. The raw logits and the label are also printed.
    """
    labels, _ = find_classes('./train')
    if transform is None:
        transform = img_transforms
    # Same default loader ImageFolder uses for training (opens the file and,
    # with the PIL backend, converts it to RGB).
    img = torchvision.datasets.folder.default_loader(filename)
    batch = transform(img).unsqueeze(0)
    # Run on whichever device the model's weights live on.
    model_device = next(model.parameters()).device
    was_training = model.training
    # eval(): disable Dropout and make BatchNorm use (not update) its running stats.
    model.eval()
    try:
        with torch.no_grad():
            prediction = model(batch.to(model_device))
    finally:
        model.train(was_training)
    print(prediction)
    label = labels[prediction.argmax().item()]
    print(label)
    return label

In [19]:
# Persist only the learned weights (the state_dict). To reuse them, rebuild the
# same architecture and call load_state_dict(torch.load("./model.pth")).
torch.save(obj=modelo_pre_treinado.state_dict(), f="./model.pth")