In [1]:
import argparse
import torch
from torchvision import datasets, transforms
import os
import glob
import numpy as np
import cv2
from PIL import Image
from PIL import ImageFile
from torchvision.utils import save_image, make_grid
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.append('../')

from ConvVAE import ConvVAE
from gradcam import GradCAM

In [2]:
class Carla_dataset(Dataset):
    
    def __init__(self, path, transforms=None, start_i=0, end_i=-1):
        
        self.path = path
        self.transforms = transforms
        self.start_i = start_i
        self.end_i = end_i
        
        self.data = self.read_data(path)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        image = self.data[index]
        label = image.split("/")[-1].split(".")[0]
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        image = Image.open(image).convert('RGB')
        if self.transforms:
            image = self.transforms(image)
        return image, label
    
    def read_data(self, path):
        images = glob.glob(os.path.join(path, "*.png"))
        images = sorted(images, key=lambda name: int(name.split('/')[-1][:-4]))
        if self.end_i == -1:
            images = images[self.start_i:]
        else:
            images = images[self.start_i:self.end_i]
        return images

In [3]:
### Save attention maps  ###
def save_cam(image, filename, gcam):
    gcam = gcam - np.min(gcam)
    gcam = gcam / np.max(gcam)
    h, w, d = image.shape
    gcam = cv2.resize(gcam, (w, h))
    gcam = cv2.applyColorMap(np.uint8(255 * gcam), cv2.COLORMAP_JET)
    gcam = np.asarray(gcam, dtype=np.float) + \
        np.asarray(image, dtype=np.float)
    gcam = 255 * gcam / np.max(gcam)
    gcam = np.uint8(gcam)
    cv2.imwrite(filename, gcam)

In [7]:
seed = 1
batch_size = 64
z_dim = 128
n_channel = 3
beta = 1
im_path = "./attention_maps_seg_2/"

In [8]:
cuda = torch.cuda.is_available()
if cuda:
    print('cuda available')
else:
    print("only cpu")

cuda available


In [9]:
cuda = torch.cuda.is_available()
if cuda:
    print('cuda available')
device = torch.device("cuda:0" if cuda else "cpu")

torch.manual_seed(seed)
trans = transforms.Compose([transforms.Resize((80, 160)), transforms.ToTensor()])
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

dataset = Carla_dataset("../Expert_samples_sem", trans, start_i=590, end_i=645)

loader = torch.utils.data.DataLoader(
                                    dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    **kwargs)

model = ConvVAE(n_channel, z_dim, beta).to(device)
model.load_state_dict(torch.load("./weights/segmodel_expert_samples_sem_all.pt", map_location=device))
mu_avg, logvar_avg = 0, 1
gcam = GradCAM(model, target_layer='encoder.6', cuda=cuda) 
test_index=0
for batch_idx, (x, _) in enumerate(loader):
    model.eval()
    x = x.to(device)
    x_rec, mu, logvar = gcam.forward(x)

    model.zero_grad()
    gcam.backward(mu, logvar, mu_avg, logvar_avg)
    gcam_map = gcam.generate()

    ## Visualize and save attention maps  ##
    for i in range(x.size(0)):
        raw_image = x[i] * 255.0
        ndarr = raw_image.permute(1, 2, 0).cpu().byte().numpy()
        im = Image.fromarray(ndarr.astype(np.uint8))
        if not os.path.exists(im_path):
            os.mkdir(im_path)
        im.save(os.path.join(im_path,
                         "{}-{}-origin.png".format(test_index, "carla")))

        file_path = os.path.join(im_path,
                             "{}-{}-attmap.png".format(test_index, "carla"))
        r_im = np.asarray(im)
        save_cam(r_im, file_path, gcam_map[i].squeeze().cpu().data.numpy())
        test_index += 1

cuda available
torch.Size([1, 55, 256, 3, 8])


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if __name__ == '__main__':


In [21]:
for module in model.named_modules():
    print(module)

('', ConvVAE(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (7): ReLU()
  )
  (decoder_input): Linear(in_features=128, out_features=6144, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2))
    (1): ReLU()
    (2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): ConvTranspose2d(64, 32, kernel_size=(5, 5), stride=(2, 2))
    (5): ReLU()
    (6): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2))
    (7): Sigmoid()
  )
  (mu): Linear(in_features=6144, out_features=128, bias=True)
  (logvar): Linear(in_features=6144, out_features=128, bias=True)
  (recons_loss): BCELoss()
))
('encoder', Sequential(
  (0): Conv2d(3, 32, kernel_si

In [13]:
import torch.nn as nn
encoderx = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Linear(6272, 1024),
            nn.ReLU()
        )

In [20]:
for module in encoderx.named_modules():
    print(module[0])


0
1
2
3
4
5
