In [1]:
import torch.nn as nn
from model import shape_decoding_digits
import matplotlib.pyplot as plt
import time
import datetime
from torch.utils.data import DataLoader
from model.models import *
from data.datasets import *
import torch
import PIL
from matplotlib.pyplot import imshow
%matplotlib inline
import kornia

In [2]:
# Training hyperparameters.
n_epochs, batch_size = 200, 30
lr = 0.0002                     # Adam learning rate
b1, b2 = 0.5, 0.999             # Adam beta coefficients
img_height = img_width = 256    # target size handed to transforms.Resize in later cells

# Weight of the pixel-wise term in the generator objective.
# The pixel-wise criterion defined below is MSELoss (squared error), not L1.
lambda_pixel = 130

# Loss functions -- all three are mean-squared-error modules.
# The spelling `criteirion_contour` is kept because later cells refer to it.
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.MSELoss()
criteirion_contour = torch.nn.MSELoss()

# Shape of the discriminator (PatchGAN) output for one sample:
# one channel, each spatial side 1/16 of the input side.
patch = (1, img_height >> 4, img_width >> 4)

In [3]:
# Initialize generator and discriminator with their default constructor
# arguments. Both classes reach this namespace through the wildcard imports
# in the first cell (presumably `model.models` -- verify).
generator = GeneratorUNet()
discriminator = Discriminator()

In [4]:
# Choose the float tensor type for this machine; later cells cast their
# inputs with `.type(Tensor)`.
cuda = torch.cuda.is_available()
if cuda:
    Tensor = torch.cuda.FloatTensor
    # Move both networks and the three loss modules onto the GPU.
    generator, discriminator = generator.cuda(), discriminator.cuda()
    for loss_module in (criterion_GAN, criterion_pixelwise, criteirion_contour):
        loss_module.cuda()
else:
    Tensor = torch.FloatTensor

In [5]:
# Initialize weights: `Module.apply` calls `weights_init_normal` on every
# submodule of each network (the function itself comes from a wildcard import).
generator.apply(weights_init_normal)
# `apply` returns the module, so as the cell's last expression this also makes
# the notebook display the discriminator's layer listing.
discriminator.apply(weights_init_normal)

Discriminator(
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d(padding=(1, 0, 1, 0), value=0.0)
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), bias=False)
  )
)

In [6]:
# One Adam optimizer per network, sharing learning rate and betas, each paired
# with a StepLR schedule that multiplies the learning rate by 0.9 every
# 3 scheduler steps.
adam_betas = (b1, b2)

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
scheduler_G = torch.optim.lr_scheduler.StepLR(optimizer_G, gamma=0.9, step_size=3)

optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)
scheduler_D = torch.optim.lr_scheduler.StepLR(optimizer_D, gamma=0.9, step_size=3)

In [8]:
# Load the semantic feature tensor for the training set from disk and show
# its shape and contents.
train_semantic_path = './data/images/images/semantic_features/train_sm_0726_cluster_20.pt'
print('Load train semantic features:')
train_semantic_vec = torch.load(train_semantic_path)
print(train_semantic_vec.shape)
train_semantic_vec  # bare last expression so the notebook renders the tensor

Load train semantic features:
torch.Size([6000, 1024])


tensor([[-0.4863,  0.3692,  0.1722,  ...,  0.4117,  0.7808, -0.3116],
        [ 0.4454,  0.0824, -0.4267,  ..., -0.5439, -0.4946,  0.3352],
        [ 0.5285,  0.3229, -0.4287,  ...,  0.5722, -0.1412, -0.0980],
        ...,
        [ 0.8580,  0.2367,  0.0675,  ...,  0.0557, -0.2398,  0.2934],
        [ 0.5958,  0.3708, -0.7983,  ...,  0.6728,  0.4681,  0.6259],
        [ 0.2518,  0.7334, -0.6548,  ...,  0.7596,  0.6879,  0.5078]],
       requires_grad=True)

In [9]:
# Load the semantic feature tensor for the test set from disk and show its
# shape and contents.
# Alternative feature file (disabled):
#   ./data/images/images/semantic_features/test_sm_0724_cluster_10.pt
test_semantic_path = "./data/images/images/semantic_features/sm_test_0703_1024.pt"
print("load test semantic features:")
test_semantic_vec = torch.load(test_semantic_path)
print(test_semantic_vec.shape)
test_semantic_vec  # bare last expression so the notebook renders the tensor

load test semantic features:
torch.Size([1200, 1024])


tensor([[-0.0749,  0.5169, -0.4813,  ...,  0.6790,  0.4429,  0.2776],
        [ 0.1711,  0.6264, -0.7709,  ...,  0.6824, -0.1504,  0.2731],
        [ 0.5564,  0.2277, -0.6691,  ...,  0.0603,  0.4212,  0.1593],
        ...,
        [ 0.6524, -0.3393, -0.1207,  ...,  0.6354,  0.4155, -0.1743],
        [-0.1301, -0.3213, -0.3402,  ...,  0.3772, -0.2399, -0.5038],
        [ 0.4575,  0.1244, -0.4902,  ...,  0.5225, -0.6285, -0.2949]],
       requires_grad=True)

In [10]:
# Cast both feature tensors to the device-specific float type chosen earlier,
# then report their shapes (train first, test second).
train_semantic_vec = train_semantic_vec.type(Tensor)
test_semantic_vec = test_semantic_vec.type(Tensor)
print(train_semantic_vec.shape)
print(test_semantic_vec.shape)

# Reference timestamp read by the training cell's ETA estimate.
prev_time = time.time()

torch.Size([6000, 1024])
torch.Size([1200, 1024])


## GAN Training

In [17]:
# Preprocessing applied to every training image: bicubic resize to
# (img_height, img_width), conversion to a tensor, then per-channel
# normalization with mean 0.5 and std 0.5.
transforms_ = [
    # Pass the InterpolationMode enum rather than the integer PIL constant;
    # torchvision warns "Argument interpolation should be of type
    # InterpolationMode instead of int." when given the int.
    transforms.Resize((img_height, img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# shuffle=False keeps batches in dataset order. The (currently disabled)
# per-batch slicing of `train_semantic_vec` in the training cell indexes by
# batch position and presumably relies on this ordering -- confirm before
# enabling shuffling.
dataloader = DataLoader(
    ImageDataset(
        "/home/anderson/Reconstructing-Perceptive-Images-from-Brain-Activity-by-Shape-Semantic-GAN/data/images/images/shape_features_groundtruth2",
        transforms_=transforms_,
        mode="train_6000_ordered_grayscale_blurred_invert",
    ),
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)




In [18]:
# Train the conditional GAN: for every batch, update the generator (adversarial
# + weighted pixel-wise loss) and then the discriminator (real vs. generated
# pairs). A generator checkpoint is written every 10 epochs and once more after
# the last epoch; `timestr` tags all checkpoints of this run.
timestr = time.strftime("%Y%m%d-%H%M%S")
print(timestr)

# Restart the ETA clock here so the first estimate measures a training batch
# rather than the time elapsed since an earlier cell was executed.
prev_time = time.time()

for epoch in range(0, n_epochs):
    for i, batch in enumerate(dataloader):
        # Model inputs. Note the key swap: the generator's input comes from
        # batch["B"] and its target from batch["A"].
        real_A = batch["B"].type(Tensor)   # contour
        real_B = batch["A"].type(Tensor)   # real image

        # Adversarial ground truths, one patch map per sample, created with
        # the same dtype and device as the inputs.
        valid = torch.ones((real_A.size(0), *patch), dtype=real_A.dtype, device=real_A.device)
        fake = torch.zeros_like(valid)

        #  Train Generator

        optimizer_G.zero_grad()

        # GAN loss. This run is the "no semantics" variant: the generator
        # receives an empty slice of the semantic features. The per-batch
        # alternative would be
        #   train_semantic_vec[i*batch_size:(i+1)*batch_size]
        fake_B = generator(real_A, train_semantic_vec[0:0])

        pred_fake = discriminator(fake_B, real_A)
        loss_GAN = criterion_GAN(pred_fake, valid)

        # Pixel-wise loss between generated and real image
        loss_pixel = criterion_pixelwise(fake_B, real_B)

        # Total generator loss
        loss_G = loss_GAN + lambda_pixel * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        #  Train Discriminator

        optimizer_D.zero_grad()

        # Real loss
        pred_real = discriminator(real_B, real_A)
        loss_real = criterion_GAN(pred_real, valid)

        # Fake loss; detach so no gradient flows back into the generator
        pred_fake = discriminator(fake_B.detach(), real_A)
        loss_fake = criterion_GAN(pred_fake, fake)

        # Total discriminator loss
        loss_D = 0.5 * (loss_real + loss_fake)

        loss_D.backward()
        optimizer_D.step()

        #  Log Progress: ETA = remaining batches * duration of the last batch
        batches_done = epoch * len(dataloader) + i
        batches_left = n_epochs * len(dataloader) - batches_done
        time_left = datetime.timedelta(seconds = batches_left * (time.time() - prev_time))
        prev_time = time.time()

        print(
            "\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"
            % (
                epoch,
                n_epochs,
                i,
                len(dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_pixel.item(),
                loss_GAN.item(),
                time_left,
            )
        )

    # Periodic checkpoint; the file name embeds the last batch's D loss.
    if epoch % 10 == 0:
        torch.save(generator.state_dict(),"./generator/"+ timestr+"_"+str(epoch)+"_"+str(loss_D.item())+"_invert_no_semantics.pt")

    # Advance the LR schedules once per epoch, after this epoch's optimizer
    # updates. Stepping the scheduler before any optimizer.step() makes
    # PyTorch skip the first value of the schedule (and emit a warning).
    scheduler_D.step()
    scheduler_G.step()

# Final checkpoint, tagged with the last epoch index and last D loss.
torch.save(generator.state_dict(),"./generator/"+ timestr+"_"+str(epoch)+"_"+str(loss_D.item())+"_invert_no_semantics_final.pt")


In [11]:
# Preprocessing for the test images: bicubic resize to (img_height, img_width),
# conversion to a tensor, then per-channel normalization with mean 0.5 and
# std 0.5.
transforms_ = [
    # Pass the InterpolationMode enum rather than the integer PIL constant;
    # torchvision warns "Argument interpolation should be of type
    # InterpolationMode instead of int." when given the int.
    transforms.Resize((img_height, img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Unshuffled loader over the "test_1200" split, 30 images per batch.
test_dataloader = DataLoader(
    ImageDataset(
        "/home/anderson/Reconstructing-Perceptive-Images-from-Brain-Activity-by-Shape-Semantic-GAN/data/images/images/GAN_prediction_20210717-055947_30_39.014461517333984.pt/",
        transforms_=transforms_,
        mode="test_1200",
    ),
    batch_size=30,
    shuffle=False,
    num_workers=0,
)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [12]:
# Rebuild the generator and restore trained weights from a saved checkpoint.
generator = GeneratorUNet()
generator.load_state_dict(
    torch.load(
        "./generator/20210728-112828_199_0.00033030787017196417_invert_no_semantics_final.pt",
        # Checkpoints are written from the (possibly CUDA-resident) training
        # model; mapping to CPU first lets the file load on machines without
        # a GPU. The model is moved to the GPU below when one is available.
        map_location="cpu",
    )
)
# NOTE(review): the generator is left in training mode. If GeneratorUNet
# contains layers that behave differently in eval mode (e.g. dropout),
# decide whether `generator.eval()` is wanted for inference -- confirm.

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion_GAN.cuda()
    criterion_pixelwise.cuda()
    criteirion_contour.cuda()

## Inference with the trained generator

In [None]:
# Replace `test_semantic_vec` with the features stored in
# test_1200_sm_features.pt, cast to the device-specific float type.
test_semantic_vec = torch.load("./data/sm_features/test_1200_sm_features.pt").type(Tensor)
print(test_semantic_vec.shape)

In [None]:
# Preprocessing for the test images: bicubic resize to (img_height, img_width),
# conversion to a tensor, then per-channel normalization with mean 0.5 and
# std 0.5.
transforms_ = [
    # Pass the InterpolationMode enum rather than the integer PIL constant;
    # torchvision warns "Argument interpolation should be of type
    # InterpolationMode instead of int." when given the int.
    transforms.Resize((img_height, img_width), transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Unshuffled loader over the "test_1200" split, 30 images per batch; the
# inference loop relies on this order to pair batches with semantic vectors.
test_dataloader = DataLoader(
    ImageDataset(
        "/home/anderson/Reconstructing-Perceptive-Images-from-Brain-Activity-by-Shape-Semantic-GAN/data/images/images/GAN_prediction_20210717-055947_30_39.014461517333984.pt/",
        transforms_=transforms_,
        mode="test_1200",
    ),
    batch_size=30,
    shuffle=False,
    num_workers=0,
)

In [None]:
# Run the generator over the test set and, for every sample, save a 512x256
# JPEG with the ground-truth image on the left and the reconstruction on the
# right. The output directory must already exist.
denorm = kornia.augmentation.Denormalize(torch.tensor((0.5, 0.5, 0.5)), torch.tensor((0.5, 0.5, 0.5)))
to_pil = transforms.ToPILImage()

# Stride used to pair batch j with its rows of `test_semantic_vec`; taken from
# the test loader itself rather than the training `batch_size` global.
test_batch_size = test_dataloader.batch_size

count = 0  # running number of images written
for j, batch in enumerate(test_dataloader):
    real_A = batch["B"].type(Tensor)
    real_B = batch["A"].type(Tensor)

    # Inference only: skip building the autograd graph.
    start = j * test_batch_size
    with torch.no_grad():
        fake_B = generator(real_A, test_semantic_vec[start:start + real_A.size(0)])
    fake_B = fake_B.cpu()
    real_B = real_B.cpu()

    # Iterate over the samples actually present so that a short final batch
    # does not raise IndexError.
    for i in range(fake_B.size(0)):
        count += 1
        # Undo the normalization, drop the leading dimension via squeeze(0)
        # as before, and convert to PIL images.
        fimg = to_pil(denorm(fake_B[i]).squeeze(0))
        raw_img = to_pil(denorm(real_B[i]).squeeze(0))

        # Side-by-side canvas: ground truth left, reconstruction right.
        # The canvas is created at its final 512x256 size, so it is saved
        # directly; the former same-size resize was a no-op and depended on
        # Image.ANTIALIAS, which Pillow 10 removed.
        target = Image.new('RGB', (256 * 2, 256))
        target.paste(raw_img, (0, 0, 256, 256))
        target.paste(fimg, (256, 0, 512, 256))
        target.save("results/ImageNet/GAN_0725_prediction_image_cluster20/test_1200_original_vec/reconstructed_img_"+str(j)+"_"+str(i)+".jpg")
print("done")