In [10]:
import numpy as np

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from models.travnet import TravNet
from utils.dataloader import CoTDataset

import matplotlib.pyplot as plt

%matplotlib inline

## Initialization

In [11]:
class Object(object):
    """Empty attribute container used as a lightweight configuration namespace."""


# All run settings live on this one object; it is handed to the model and
# dataset constructors elsewhere in the notebook.
params = Object()

for _section in (
    {   # dataset parameters
        'train_data_path': r'D:\data_valid\data_valid',
        # (a separate `valid_data_path` is intentionally not defined)
        'csv_file': 'data.csv',
        'preproc': True,            # Vertical flip augmentation
        'compute_stats': False,
        'depth_mean': 1.295778,
        'depth_std': 3.441738,
    },
    {   # training parameters
        'seed': 230,
        'epochs': 50,
        'batch_size': 16,
        'learning_rate': 1e-4,
        'weight_decay': 1e-5,
    },
    {   # model parameters
        'pretrained': True,
        'load_network_path': r'D:\data_valid\checkpoints\best_wayfast.pth',
        # Sizes are stored as (width, height): the dummy forward pass below
        # builds its tensors as [..., input_size[1], input_size[0]].
        'input_size': (336, 188),
        'output_size': (336, 188),
        'output_channels': 1,
        'bottleneck_dim': 256,
    },
):
    # Attach every entry of the section as an attribute, preserving order.
    for _name, _value in _section.items():
        setattr(params, _name, _value)

In [12]:
# Seed PyTorch's random number generators from the configured seed and pick
# the compute device: the first CUDA GPU when one is available, else the CPU.
_cuda_ok = torch.cuda.is_available()

torch.manual_seed(params.seed)
if _cuda_ok:
    torch.cuda.manual_seed(params.seed)

device = torch.device("cuda:0") if _cuda_ok else torch.device("cpu")
print("device:", device)

device: cpu


In [13]:
# Build the model from the run configuration and place it on the chosen device.
net = TravNet(params).to(device)

# Optionally warm-start from a previously trained network. The weights are
# loaded into the bare model, i.e. BEFORE it is wrapped in DataParallel below.
if params.load_network_path is not None:
    print('Loading saved network from {}'.format(params.load_network_path))
    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint written on a GPU machine also loads on a CPU-only machine.
    state_dict = torch.load(params.load_network_path, map_location=device)
    net.load_state_dict(state_dict)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")

# The DataParallel wrapper is applied regardless of the GPU count: the
# checkpointing code in this notebook reaches the underlying model through
# `net.module`, which only exists on the wrapper.
net = torch.nn.DataParallel(net).to(device)



Loading saved network from D:\data_valid\checkpoints\best_wayfast.pth


In [14]:
# Smoke test: push one random batch through the network and print the shape
# of the result. `params.input_size` is read as (width, height), while the
# tensors are laid out as [batch, channels, height, width].
_width, _height = params.input_size

rgb_test = torch.rand(params.batch_size, 3, _height, _width).to(device)    # 3-channel input
depth_test = torch.rand(params.batch_size, 2, _height, _width).to(device)  # 2-channel input
test = net(rgb_test, depth_test)
print('test.shape:', test.shape)

test.shape: torch.Size([16, 1, 188, 336])


In [15]:
# Image pipeline handed to the dataset: convert to PIL, apply a random colour
# jitter, convert back to a tensor, then normalise each channel with the
# given mean / std.
_pipeline_steps = [
    transforms.ToPILImage(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_pipeline_steps)

# Only a training split is built; no separate validation dataset or loader is
# created (params.valid_data_path is not defined in the configuration).
train_dataset = CoTDataset(params, params.train_data_path, transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=params.batch_size,
    shuffle=True,
    num_workers=1,
)

print('Loaded %d train images' % len(train_dataset))

Initializing dataset
weights: [0.80546092 1.         1.         1.         0.19453908]
bins: [0.  0.2 0.4 0.6 0.8 1. ]
Loaded 2996 train images


In [16]:
# Fetch the first sample from the dataset (runs CoTDataset's indexing once).
# NOTE(review): presumably a quick sanity check that a sample can be read;
# the value is not inspected in this cell.
data = train_dataset[0]

## Set up training tools

In [17]:
# Adam over all network parameters, with step size and weight decay taken
# from the run configuration.
_adam_settings = {
    'lr': params.learning_rate,
    'weight_decay': params.weight_decay,
}
optimizer = torch.optim.Adam(net.parameters(), **_adam_settings)

# reduction='none' keeps the element-wise absolute errors so the training
# loop can weight them before averaging.
criterion = torch.nn.L1Loss(reduction='none')

## Train the network

In [18]:
# Per-channel mean / std given to transforms.Normalize when the input pipeline
# was built; needed here to undo the normalisation before displaying an image.
NORM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
NORM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# NOTE(review): CHECKPOINT_BEST is the same file as params.load_network_path.
# Because best_val_loss starts at +inf, the first epoch overwrites the
# checkpoint this run was initialised from. Point it elsewhere to keep it.
CHECKPOINT_BEST = r'D:\data_valid\checkpoints\best_wayfast.pth'
CHECKPOINT_LAST = r'D:\data_valid\checkpoints\wayfast.pth'


def _forward_loss(batch):
    """Run the network on one batch and return (loss, color_img, pred).

    `batch` is unpacked as (color_img, depth_img, path_img, cot_img, weight);
    every item is moved to `device` and cast to float32. The loss is the mean
    of `weight * L1(pred * path_img, cot_img)`.
    NOTE(review): `path_img` presumably masks the prediction and `weight`
    re-weights the per-element error -- confirm against CoTDataset.
    """
    color_img, depth_img, path_img, cot_img, weight = (
        item.to(device).type(torch.float32) for item in batch
    )
    pred = net(color_img, depth_img)
    loss = torch.mean(weight * criterion(pred * path_img, cot_img))
    return loss, color_img, pred


def _denormalize(img):
    """Undo the input normalisation of one 3xHxW image for display.

    Returns an HxWx3 numpy array clipped to [0, 1]. Works on a new tensor, so
    the batch the image came from is left unmodified.
    """
    restored = img.detach().cpu() * NORM_STD + NORM_MEAN
    return restored.clamp(0, 1).permute(1, 2, 0).numpy()


best_val_loss = np.inf
train_loss_list = []
val_loss_list = []

# Save the bare model's weights whether or not `net` is wrapped in DataParallel.
model_to_save = getattr(net, 'module', net)

for epoch in range(params.epochs):

    # ---- optimisation pass ----
    net.train()
    train_loss = 0.0
    for batch in train_loader:
        loss, _, _ = _forward_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(train_loader)
    train_loss_list.append(train_loss)

    if epoch % 10 == 0:
        outstring = 'Epoch [%d/%d], Loss: ' % (epoch+1, params.epochs)
        print(outstring, train_loss)
        print('Learning Rate for this epoch: {}'.format(optimizer.param_groups[0]['lr']))

    # ---- evaluation pass ----
    # NOTE(review): this iterates train_loader, so "val_loss" is measured on
    # the training data (in eval mode, without gradients), and best-checkpoint
    # selection is therefore not based on held-out data.
    net.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in train_loader:
            loss, color_img, pred = _forward_loss(batch)
            val_loss += loss.item()
    val_loss /= len(train_loader)
    val_loss_list.append(val_loss)

    # ---- visualise the first sample of the last evaluated batch ----
    fig, (ax_rgb, ax_pred) = plt.subplots(1, 2, figsize=(14, 14))
    ax_rgb.imshow(_denormalize(color_img[0]))
    ax_rgb.set_title('Input image (epoch %d)' % (epoch + 1))
    ax_pred.imshow(pred[0, 0, :, :].detach().cpu().numpy(), vmin=0, vmax=1)
    ax_pred.set_title('Prediction')
    plt.show(block=False)
    plt.close(fig)  # avoid accumulating one open figure per epoch

    # ---- keep the weights with the lowest evaluation loss so far ----
    if best_val_loss > val_loss:
        best_val_loss = val_loss
        print('Updating best test loss: %.5f' % best_val_loss)
        torch.save(model_to_save.state_dict(), CHECKPOINT_BEST)

# Always store the weights from the final epoch as well.
torch.save(model_to_save.state_dict(), CHECKPOINT_LAST)