In [1]:
import torch.nn as nn
import torchvision
import torch.utils.data
import torch.optim as optim
from models.medium_fully_connected_net import FullyConnectedNetMedium
from models.small_fully_connected_net import FullyConnectedNetSmall
from models.big_fully_connected_net import FullyConnectedNetBig

In [2]:
# Dataset loading
folder_train = torchvision.datasets.ImageFolder(
    root="../train_images",
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(),
        torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        torchvision.transforms.Resize((40, 60)),
        torchvision.transforms.ToTensor()
    ])
)

train_loader = torch.utils.data.DataLoader(
    folder_train,
    batch_size=32,
    shuffle=True
)

In [3]:
# Train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FullyConnectedNetMedium().to(device).train()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.BCEWithLogitsLoss()
num_episodes = 25

best_loss = float("inf")
for episode in range(num_episodes):
    episode_loss = 0

    for images, labels in train_loader:
        # Preprocess
        images, labels = images.to(device), labels.float().to(device).view(-1, 1)

        # Reset gradients
        optimizer.zero_grad()

        # Feed stuff into model
        outputs = model(images)

        # Compute loss and backpropagate
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate loss
        episode_loss += loss.item()

    print(f"Episode {episode}, loss: {episode_loss}")

    if episode_loss < best_loss:
        best_loss = episode_loss
        torch.save(model.state_dict(), "../trained/medium_fully_connected_net.pth")
        print("\tSaving!")


Episode 0, loss: 51.3785265982151
	Saving!
Episode 1, loss: 23.734922759234905
	Saving!
Episode 2, loss: 17.401051979511976
	Saving!
Episode 3, loss: 15.536154542118311
	Saving!
Episode 4, loss: 12.579902097582817
	Saving!
Episode 5, loss: 11.752529822289944
	Saving!
Episode 6, loss: 10.365159714594483
	Saving!
Episode 7, loss: 9.508501725271344
	Saving!
Episode 8, loss: 8.776970124803483
	Saving!
Episode 9, loss: 7.271652254275978
	Saving!
Episode 10, loss: 7.02912553679198
	Saving!
Episode 11, loss: 6.643298786599189
	Saving!
Episode 12, loss: 5.534543432760984
	Saving!
Episode 13, loss: 5.7732874299399555
Episode 14, loss: 6.973274164367467
Episode 15, loss: 10.94480085466057
Episode 16, loss: 7.6513963947072625
Episode 17, loss: 6.53784887585789
Episode 18, loss: 6.53201734344475
Episode 19, loss: 5.5929541774094105
Episode 20, loss: 6.3347403379157186
Episode 21, loss: 4.643703497946262
	Saving!
Episode 22, loss: 4.717167695518583
Episode 23, loss: 7.724742140155286
Episode 24, lo