In [1]:
import sys
sys.path.insert(1, '/kaggle/input/du2project-code/project')

In [12]:
import torch
import torchvision
import torch.utils.data as data
import os
from os.path import join
import argparse
import logging
from tqdm import tqdm
from torchvision.utils import save_image
#user import
from data_generator.DataLoader_IpcGans import CACD
from model.GAN import IPCGANs
from utils.io import check_dir,Img_to_zero_center,Reverse_zero_center
from datetime import datetime

In [3]:
TIMESTAMP = "{0:%Y-%m-%d_%H-%M-%S}".format(datetime.now())

In [4]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__
    
    
args = {
    'learning_rate' : 1e-4,
    'batch_size' : 96,
    'max_epoches' : 200,
    'val_interval' : 1,
    'save_interval' : 1, 
    'd_iter' : 1,
    'g_iter' : 2,
    'gan_loss_weight' : 75,
    'feature_loss_weight' : 0.5e-4,
    'age_loss_weight' : 30,
    'age_groups' : 5,
    'age_classifier_path' : '/kaggle/input/du2project-code/projekt/checkpoint/pretrain_alexnet/saved_parameters/epoch_20_iter_0.pth',
    'feature_extractor_path' : '/kaggle/input/du2project-code/project/pre_trained/alexnet/alexnet.model-292000.data-00000-of-00001', 
    'checkpoint' : '/kaggle/working/checkpoint/IPCGANS/%s'%(TIMESTAMP),
    'saved_model_folder' : '/kaggle/working/checkpoint/IPCGANS/%s/saved_parameters/'%(TIMESTAMP),
    'saved_validation_folder' : '/kaggle/working/checkpoint/IPCGANS/%s/validation/'%(TIMESTAMP),
    'tensorboard_log_folder' : '/kaggle/working/checkpoint/IPCGANS/%s/tensorboard/'%(TIMESTAMP),
    'list_root' : '/kaggle/input/du2project-code/project/data/cacd2000-lists',
    'data_root' : '/kaggle/input/du2project-data/CACD2000',
}



args = dotdict(args)

In [11]:
check_dir(args.checkpoint)
check_dir(args.saved_model_folder)
check_dir(args.saved_validation_folder)

FileExistsError: [Errno 17] File exists: '/kaggle/working/checkpoint/IPCGANS/2023-01-03_14-35-27'

In [13]:
logger = logging.getLogger("IPCGANS Train")
file_handler = logging.FileHandler(join(args.checkpoint, 'log.txt'), "w")
stdout_handler = logging.StreamHandler()
logger.addHandler(file_handler)
logger.addHandler(stdout_handler)
stdout_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
file_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(message)s'))
logger.setLevel(logging.INFO)

In [15]:

transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    Img_to_zero_center()
])
label_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
])
#step4: define train/test dataloader
train_dataset = CACD(
    list_root = args.list_root, 
    data_root= args.data_root, 
    split = "train",
    transforms=transforms, label_transforms=label_transforms
)
test_dataset = CACD(
    list_root = args.list_root, 
    data_root= args.data_root, 
    split = "test",
    transforms=transforms, label_transforms=label_transforms
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=args.batch_size,
    shuffle=False
)

In [16]:
model=IPCGANs(lr=args.learning_rate,age_classifier_path=args.age_classifier_path,gan_loss_weight=args.gan_loss_weight,feature_loss_weight=args.feature_loss_weight,age_loss_weight=args.age_loss_weight)
                  #feature_extractor_path=args.feature_extractor_path)

In [2]:
d_optim=model.d_optim
g_optim=model.g_optim

for epoch in range(args.max_epoches):
    pbar = tqdm(enumerate(train_loader,1),total = len(train_loader))
    for idx, (source_img_227,source_img_128,true_label_img,true_label_128,true_label_64,fake_label_64, true_label) in pbar:

        running_d_loss=None
        running_g_loss=None
        n_iter = epoch * len(train_loader) + idx


        #mv to gpu
        source_img_227=source_img_227.cuda()
        source_img_128=source_img_128.cuda()
        true_label_img=true_label_img.cuda()
        true_label_128=true_label_128.cuda()
        true_label_64=true_label_64.cuda()
        fake_label_64=fake_label_64.cuda()
        true_label=true_label.cuda()

        #train discriminator
        for d_iter in range(args.d_iter):
            #d_lr_scheduler.step()
            d_optim.zero_grad()
            model.train(
                source_img_227=source_img_227,
                source_img_128=source_img_128,
                true_label_img=true_label_img,
                true_label_128=true_label_128,
                true_label_64=true_label_64,
                fake_label_64=fake_label_64,
                age_label=true_label
            )
            d_loss=model.d_loss
            running_d_loss=d_loss
            d_loss.backward()
            d_optim.step()

        #visualize params
        """for name, param in model.discriminator.named_parameters():
            writer.add_histogram("discriminator:%s"%name, param.clone().cpu().detach().numpy(), n_iter)"""

        #train generator
        for g_iter in range(args.g_iter):
            #g_lr_scheduler.step()
            g_optim.zero_grad()
            model.train(
                source_img_227=source_img_227,
                source_img_128=source_img_128,
                true_label_img=true_label_img,
                true_label_128=true_label_128,
                true_label_64=true_label_64,
                fake_label_64=fake_label_64,
                age_label=true_label
            )
            g_loss = model.g_loss
            running_g_loss=g_loss
            g_loss.backward()
            g_optim.step()

        """for name, param in model.generator.named_parameters():
            writer.add_histogram("generator:%s" % name, param.clone().cpu().detach().numpy(), n_iter)"""

        format_str = ('step %d/%d, g_loss = %.3f, d_loss = %.3f')
        #logger.info(format_str % (idx, len(train_loader),running_g_loss,running_d_loss))
        pbar.set_postfix({'g_loss': '%.3f' %(running_g_loss),'d_loss': '%.3f' %(running_d_loss) })


        #writer.add_scalars('data/loss', {'G_loss':running_g_loss,'D_loss':running_d_loss}, n_iter)

        # save the parameters at the end of each save interval
    if epoch % args.save_interval == 0:
        model.save_model(dir=args.saved_model_folder,
                         filename='epoch_%d_iter_%d.pth'%(epoch, idx))
        logger.info('checkpoint has been created!')
        print('checkpoint has been created!')

    #val step
    if epoch % args.val_interval == 0:
        save_dir = os.path.join(args.saved_validation_folder, "epoch_%d" % epoch, "idx_%d" % idx)
        check_dir(save_dir)
        for val_idx, (source_img_128, true_label_128) in enumerate(tqdm(test_loader)):
            save_image(Reverse_zero_center()(source_img_128),fp=os.path.join(save_dir,"batch_%d_source.jpg"%(val_idx)))

            pic_list = []
            pic_list.append(source_img_128)
            for age in range(args.age_groups):
                img = model.test_generate(source_img_128, true_label_128[age])
                save_image(Reverse_zero_center()(img),fp=os.path.join(save_dir,"batch_%d_age_group_%d.jpg"%(val_idx,age)))
        logger.info('validation image has been created!')
        print('validation image has been created!')

NameError: name 'model' is not defined