In [1]:
import os
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from datasets import MyDatasets
from model import Model
from train_utils import training
import matplotlib.pyplot as plt
import logging
# Module-level logger that emits DEBUG-and-above records through a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Loggers are process-wide singletons keyed by name, so executing this setup a
# second time (re-running a notebook cell, reloading the module) would stack
# another handler and print every message multiple times.
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

def pre_process(device):
    """Build the training and validation data loaders for the image dataset.

    Seeds PyTorch's global RNG, loads ``MyDatasets`` from the ``images``
    directory with a ``ToTensor`` transform, splits it 80/20 using a seeded
    generator, and wraps both subsets in ``DataLoader`` objects with a batch
    size of 32.

    Args:
        device: Device identifier accepted by ``torch.device`` (for example
            ``"cuda"`` or ``"cpu"``). Used for logging and to decide whether
            batches are placed in pinned memory.

    Returns:
        A ``(train_dataloader, val_dataloader)`` tuple.
    """
    logger.debug(f'cwd: {os.getcwd()}')
    logger.info(f"Using {device} device")
    # Pinned host memory only benefits host-to-GPU copies, so request it for
    # CUDA targets only. Constructing the torch.device also rejects an invalid
    # identifier up front.
    pin_memory = torch.device(device).type == "cuda"

    SEED = 1
    torch.manual_seed(SEED)

    IMAGES_DIR = "images"
    TRANSFORMER = transforms.Compose([
        transforms.ToTensor(),
    ])
    SPLIT_RATIO = 0.8
    BATCH_SIZE = 32
    NUM_WORKERS = 4

    dataset = MyDatasets(directory=IMAGES_DIR, transform=TRANSFORMER)
    logger.info(f"Dataset size: {len(dataset)}")

    # The validation split takes whatever the truncated training share leaves,
    # so the two sizes always add up to len(dataset).
    train_size = int(len(dataset) * SPLIT_RATIO)
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(SEED),
    )
    logger.info(f"Train dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        drop_last=True,
        pin_memory=pin_memory,
    )
    # Validation must see every sample: with drop_last=True a split smaller
    # than one batch produces no batches at all. Shuffling is unnecessary for
    # evaluation, so the order is kept fixed.
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        drop_last=False,
        pin_memory=pin_memory,
    )
    return train_dataloader, val_dataloader
    
def render_graph(x, y):
    """Draw ``y`` against ``x`` on the current axes and display the figure."""
    current_axes = plt.gca()
    current_axes.plot(x, y)
    plt.show()
    
    
def main():
    """Train ``Model`` for a fixed number of epochs and save its weights.

    Selects CUDA when available (CPU otherwise), builds the data loaders via
    ``pre_process``, calls ``training`` once per epoch while collecting the
    losses it returns, and finally writes the model's ``state_dict`` to
    ``model.pth`` in the current working directory.
    """
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    train_dataloader, val_dataloader = pre_process(DEVICE)
    logger.debug(f"Train data size: {len(train_dataloader)}")

    # Place the model on the target device before the optimizer is built so
    # the optimizer is created against the parameters that will be trained.
    # NOTE(review): assumes Model is a torch.nn.Module -- confirm.
    model = Model().to(DEVICE)

    # Per-epoch values returned by training(); they are collected here but
    # not consumed anywhere else in this function.
    train_loss = []
    val_loss = []

    EPOCHS = 10
    LEARNING_RATE = 1e-3

    CRITERION = torch.nn.CrossEntropyLoss()
    OPTIMIZER = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    for epoch in range(EPOCHS):
        logger.info(f"Epoch: {epoch}")
        model, train_batch_loss, val_batch_loss = training(
            model, train_dataloader, val_dataloader, CRITERION, OPTIMIZER, DEVICE
        )
        train_loss.append(train_batch_loss)
        val_loss.append(val_batch_loss)

    torch.save(model.state_dict(), 'model.pth')
    logger.info("Model saved")
    
    
# Start training only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()

cwd: /code/janken
Using cpu device
Dataset size: 136
Train dataset size: 108
Validation dataset size: 28
Train data size: 3
Epoch: 0
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.1209462881088257
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.176207184791565
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.0990710258483887
Epoch: 1
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.1451963186264038
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.0817409753799438
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1.1139447689056396
Epoch: 2
data: torch.Size([32, 3, 480, 640]) | label: torch.Size([32, 3])
output: torch.Size([32, 3])
train_loss: 1