In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
import h5py
from torchvision.transforms import InterpolationMode
import io

In [2]:
nc = 3
ndf = 64

class Discriminator(nn.Module):
    def __init__(self, ngpu, nc = 3, ndf = 64):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 8, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            
            nn.Conv2d(ndf * 4, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)

        return output.view(-1, 1).squeeze(1)

In [3]:
def get_transform():
    transform = transforms.Compose([
        transforms.Resize(256, interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(256),
        transforms.RandomRotation(30),  # 随机旋转 -30 到 30 度
        transforms.RandomHorizontalFlip(p=0.5),  # 以 50% 概率水平翻转
        transforms.RandomVerticalFlip(p=0.5),  # 以 50% 概率垂直翻转
        # transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0),  # 随机改变亮度、对比度和饱和度
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform

In [4]:
class CustomDataset(Dataset):
    def __init__(self, csv_file, hdf5_file, transform=None, target_label=1):
        """
        Args:
            csv_file (string): Path to the CSV file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
            target_label (int, optional): Target label to filter and load specific class images.
        """
        self.annotations = pd.read_csv(csv_file)
        self.hdf5_file = h5py.File(hdf5_file, 'r')
        self.transform = transform
        
        # 只保留目标标签的数据
        self.annotations = self.annotations[self.annotations['target'] == target_label]

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        isic_id = self.annotations.iloc[idx]["isic_id"]
        if isic_id in self.hdf5_file:
            image = self.hdf5_file[isic_id]
            # Check if the data is numerical before conversion
            image_data = image[()]
            # 将字节字符串解码为图像
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
                

        if self.transform:
            image = self.transform(image)

        return image, 1

In [5]:
mytransform = get_transform()
dataset = CustomDataset(csv_file="../../data/train-metadata.csv", hdf5_file="../../data/train-image.hdf5", transform=mytransform)
dataset.__len__()

  self.annotations = pd.read_csv(csv_file)


393

In [6]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

In [7]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=16)
ngpu = 1
device = torch.device("cuda:1")
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
# for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
data = next(iter(dataloader))
print(data[0].shape)
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
label = torch.full((batch_size,), 1,
                        dtype=real_cpu.dtype, device=device)

output = netD(real_cpu)
print(output.shape)

torch.Size([128, 3, 224, 224])
torch.Size([128])


In [8]:
ngf =64
nz = 100
ngpu = 1
nc = 3
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf * 4,     ngf*2, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # #128*128
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            #256*256
            nn.ConvTranspose2d( ngf,      nc, 5, 1, 2, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

In [9]:
nz = 100
ngf = 64
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
noise = torch.randn(batch_size, nz, 1, 1, device=device)
fake = netG(noise)
print(fake.shape)

torch.Size([128, 3, 256, 256])


In [10]:
ngf = 64
nz = 100
nc = 3
ngpu = 1
device = torch.device("cuda:1")
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(     nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2,     ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf * 4,     ngf*2, 4, 4, 0, bias=False),
            nn.BatchNorm2d(ngf*2),
            nn.ReLU(True),
            # #128*128
            nn.ConvTranspose2d(ngf * 2,     ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            #256*256
            nn.ConvTranspose2d( ngf,      nc, 5, 1, 2, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output





In [11]:
import os
generators = [net for net in os.listdir() if net.startswith('netG') and net.endswith('.pth')]

# path = "./netG_epoch_30.pth"
for path in generators:
    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)
    if path!= '':
        netG.load_state_dict(torch.load(path))
        print("Loaded model from {}".format(path))
    torch.manual_seed(999)
    for i in range(100):
        fixed_noise = torch.randn(1, nz, 1, 1, device=device)
        # print(fixed_noise)
        fake = netG(fixed_noise)
        vutils.save_image(fake.detach(),
                        f'./images/{path}_{i}.png',
                        normalize=True)



    
# print(netG)

In [12]:
torch.manual_seed(999)
for i in range(10):
    fixed_noise = torch.randn(1, nz, 1, 1, device=device)
    print(fixed_noise)
    fake = netG(fixed_noise)
    vutils.save_image(fake.detach(),
                    '%s/fake_samples_epoch_%03d.png' % ('.', i),
                    normalize=True)

tensor([[[[-0.4475]],

         [[ 0.9223]],

         [[-0.2639]],

         [[ 2.2199]],

         [[-1.7821]],

         [[-1.5525]],

         [[-0.0510]],

         [[ 2.0793]],

         [[ 1.8283]],

         [[ 1.0873]],

         [[ 1.0379]],

         [[-0.6214]],

         [[-0.5439]],

         [[-0.9079]],

         [[ 0.1143]],

         [[ 0.3452]],

         [[ 0.4308]],

         [[ 0.2358]],

         [[ 0.5563]],

         [[ 1.1140]],

         [[ 1.5716]],

         [[ 0.1618]],

         [[ 0.5240]],

         [[ 1.0051]],

         [[-0.5359]],

         [[ 0.4478]],

         [[-0.3239]],

         [[-0.4393]],

         [[ 0.4730]],

         [[ 0.3857]],

         [[ 1.2216]],

         [[-0.1880]],

         [[-0.7251]],

         [[ 1.2665]],

         [[-1.6395]],

         [[ 0.7707]],

         [[-0.7572]],

         [[ 0.3397]],

         [[-0.3407]],

         [[-0.8857]],

         [[-0.5045]],

         [[-0.5756]],

         [[-0.4674]],

         [[