In [1]:
from exp.nb_data import *
from exp.nb_1model import *
from exp.nb_all import *
from exp.nb_Loss import *

In [2]:
import gc

In [3]:
# Per-image transforms, applied in this order (all defined in the exp.* modules).
tfms = [
    into_rgb,
    ResizeFixed(128),
    to_byte_tensor,
    to_float_tensor,
]

In [4]:
# Collect the image files under `path` (recurse=True), each with the `tfms` pipeline.
il = ImageList.from_files(path, image_extensions, recurse=True, tfms=tfms)

# Wrap the items with `label_none`, pointing it at the CelebA attribute CSV.
ll = LabeledList.label_none(il, path / 'list_attr_celeba.csv')

# Batches of 32 loaded by a single worker process.
dl = DataLoader(ll, batch_size=32, num_workers=1)

In [5]:
class DataBunch():
    """Hold the data loaders consumed by the runner.

    Args:
        train_dl: training data loader (must expose ``.dataset``).
        valid_dl: optional validation loader; ``None`` when there is none.
    """
    def __init__(self,train_dl,valid_dl = None):
        self.train_dl = train_dl
        # Always define the attribute so callers can test `valid_dl is None`
        # instead of hitting an AttributeError.
        self.valid_dl = valid_dl

    @property
    def train_ds(self): return self.train_dl.dataset

    @property
    def valid_ds(self):
        """Dataset behind ``valid_dl``, or ``None`` when there is no validation loader."""
        return None if self.valid_dl is None else self.valid_dl.dataset

In [6]:
# Training loader only -- no validation loader is passed.
data = DataBunch(train_dl=dl)

In [7]:
class Runner1():
    """Run a training loop for a ``Learner`` and dispatch events to callbacks.

    Differs from a plain runner in that the loss is computed as
    ``loss_func(model, pred, yb)``: the model is passed as the first argument.

    Args:
        cbs: callback instance(s) to register.
        cb_funcs: zero-argument factories; each resulting callback is appended
            to the callback list and also stored on the runner under ``cb.name``.
    """
    def __init__(self, cbs=None, cb_funcs=None):
        cbs = listify(cbs)
        for cbf in listify(cb_funcs):
            cb = cbf()
            setattr(self, cb.name, cb)
            cbs.append(cb)
        # TrainEvalCallback is always installed; it maintains n_epoch/n_iters/in_train.
        self.stop,self.cbs = False,[TrainEvalCallback()]+cbs

    # Convenience accessors; only valid once `fit` has bound `self.learn`.
    @property
    def opt(self):       return self.learn.opt
    @property
    def model(self):     return self.learn.model
    @property
    def loss_func(self): return self.learn.loss_func
    @property
    def data(self):      return self.learn.data

    def one_batch(self, xb, yb):
        """Forward one batch, compute the loss and, when training, step the optimizer."""
        try:
            self.xb,self.yb = xb,yb
            self('begin_batch')
            self.pred = self.model(self.xb)
            self('after_pred')
            self.loss = self.loss_func(self.model,self.pred, self.yb)
            self('after_loss')
            if not self.in_train: return
            self.loss.backward()
            self('after_backward')
            self.opt.step()
            self('after_step')
            self.opt.zero_grad()
        except CancelBatchException: self('after_cancel_batch')
        finally: self('after_batch')

    def all_batches(self, dl):
        """Run ``one_batch`` over every batch of ``dl``."""
        self.iters = len(dl)
        try:
            for xb,yb in dl: self.one_batch(xb, yb)
        except CancelEpochException: self('after_cancel_epoch')

    def fit(self, epochs, learn):
        """Train ``learn`` for ``epochs`` epochs.

        The validation pass is skipped when any callback's ``begin_validate``
        returns a truthy value, or when ``learn.data`` has no ``valid_dl``
        (or it is ``None``).
        """
        self.epochs,self.learn,self.loss = epochs,learn,tensor(0.)

        try:
            for cb in self.cbs: cb.set_runner(self)
            self('begin_fit')
            for epoch in range(epochs):
                self.epoch = epoch
                if not self('begin_epoch'): self.all_batches(self.data.train_dl)

                with torch.no_grad():
                    skip_valid = self('begin_validate')
                    # The data object may have no validation loader at all;
                    # skip the pass instead of raising AttributeError/TypeError.
                    valid_dl = getattr(self.data, 'valid_dl', None)
                    if not skip_valid and valid_dl is not None:
                        self.all_batches(valid_dl)
                self('after_epoch')

        except CancelTrainException: self('after_cancel_train')
        finally:
            self('after_fit')

    def __call__(self, cb_name):
        """Invoke ``cb_name`` on every callback in ``_order``; True if any returned truthy."""
        res = False
        for cb in sorted(self.cbs, key=lambda x: x._order): res = cb(cb_name) or res
        return res

In [8]:
# def kl_loss(attr,model,pred,targ):
#     e = getattr(model,attr)
#     return (-0.5*torch.sum(1 + e.log_var - torch.pow(e.mean,2) - torch.exp(e.log_var), axis = 1)).sum()/int(pred.shape[0])
    

# def total_loss(r_loss_factor,attr,model,pred,targ):
#     return r_loss_factor*F.mse_loss(pred,targ) + kl_loss(attr,model,pred,targ)

# class Total_loss():
#     def __init__(self, r = 10000, attr = 'enc'):
#         self.r = r
#         self.a = attr
#         self.m_loss, self.kl_loss = 0.,0.
        
#     def __call__(self,model,pred,targ):
#         e = getattr(model,self.a)
#         self.m_loss = F.mse_loss(pred,targ)
#         self.kl_loss = (-0.5*torch.sum(1 + e.log_var - torch.pow(e.mean,2) - torch.exp(e.log_var), axis = 1)).mean()
#         return self.r*self.m_loss + self.kl_loss

In [9]:
class TrainEvalCallback(Callback):
    """Track training progress on the runner and put the model in training mode.

    Sets ``run.n_epoch`` (fractional: advanced by ``1/iters`` per training batch),
    ``run.n_iters`` (number of training batches seen) and ``run.in_train``.
    ``begin_validate`` returns True, which makes the runner skip validation.
    """
    _order = 1
    def begin_fit(self):
        # n_iters counts batches, so keep it an int (other callbacks print it
        # and test it with `%`).
        self.run.n_epoch, self.run.n_iters = 0., 0

    def begin_epoch(self):
        self.model.train()
        self.run.in_train = True
        # Resynchronise the fractional counter with the integer epoch index.
        self.run.n_epoch = self.epoch

    def begin_batch(self):
        if self.run.in_train:
            self.run.n_epoch += 1/self.iters
            self.run.n_iters += 1

    def begin_validate(self):
        # Returning True makes Runner1.fit skip the validation pass entirely.
        return True

class CudaCallback(Callback):
    """Put the model and every batch on the default CUDA device."""
    _order = 30

    def begin_fit(self):
        # Move the parameters once, before the first batch.
        self.model.cuda()

    def begin_batch(self):
        # Replace the runner's batch with GPU copies.
        xb_gpu = self.xb.cuda()
        yb_gpu = self.yb.cuda()
        self.run.xb, self.run.yb = xb_gpu, yb_gpu

class PrintLossCallback(Callback):
    """Print the loss components whenever ``n_iters`` is a multiple of 10.

    Reads ``ploss`` and ``kl_loss`` off the loss-function object, so it
    assumes a loss exposing those attributes -- TODO confirm for other losses.
    """

    def after_loss(self):
        if self.n_iters % 10:
            return
        print(f'{self.n_iters} iterations -> Perceptual={self.loss_func.ploss},Kl={self.loss_func.kl_loss}, Total={self.loss}')



class ParamSchedulerCallback(Callback):
    """Overwrite an optimizer hyper-parameter from a schedule after each training loss.

    Args:
        param: key to set in every optimizer param group (e.g. ``'lr'``).
        sched_func: called with the fraction of training done
            (``run.n_epoch / run.epochs``) and returns the value to set.
    """
    def __init__(self, param, sched_func):
        self.param, self.sf = param, sched_func

    def change_param(self):
        # Progress through the whole run, stored on the callback as `po`.
        self.po = self.run.n_epoch / self.run.epochs
        for group in self.opt.param_groups:
            group[self.param] = self.sf(self.po)

    def after_loss(self):
        # Schedule only during training batches.
        if not self.in_train:
            return
        self.change_param()

    
class RecorderCallback(Callback):
    """Record the learning rate and loss of every training batch for plotting."""
    def begin_fit(self):
        self.lrs, self.losses = [], []

    def after_loss(self):
        if not self.in_train:
            return
        # The lr is read from the last param group only.
        last_group = self.opt.param_groups[-1]
        self.lrs.append(last_group['lr'])
        self.losses.append(self.loss.detach().cpu())

    def plot_lr(self): plt.plot(self.lrs)

    def plot_losses(self): plt.plot(self.losses)

In [10]:
class TestCallback(Callback):
    """Debug helper: print the iteration counter and stop training once it exceeds 30."""
    def begin_batch(self):
        count = self.n_iters
        print(count)
        if count > 30: raise CancelTrainException()

In [11]:
class StateDictCallback(Callback):
    """Periodically save the runner's model weights to ``<path>/state_dict``.

    Args:
        path: directory (``str`` or ``os.PathLike``) to write the checkpoint into.
        every: save whenever the training-iteration counter is a multiple of this.
    """
    def __init__(self,path,every=1000):
        self.path = path
        self.every = every

    def after_batch(self):
        # n_iters only advances on training batches; without the in_train check
        # the same checkpoint would be rewritten after every validation batch.
        if not self.in_train or self.n_iters % self.every:
            return
        import os
        # Save the model the runner is training, not a notebook-global `model`.
        torch.save(self.model.state_dict(), os.path.join(self.path, 'state_dict'))


In [12]:
import os

# Checkpoint directory for StateDictCallback. Set the CHECKPOINT_DIR
# environment variable to override it instead of editing a local absolute path.
CHECKPOINT_DIR = os.environ.get('CHECKPOINT_DIR', 'C:/Users/iamab/OneDrive/Documents/project')

# Callback factories for Runner1 (each is called with no arguments).
# NOTE(review): TestCallback cancels training after ~30 iterations -- drop it for a real run.
cbs = [CudaCallback, RecorderCallback, PrintLossCallback,
       partial(StateDictCallback, CHECKPOINT_DIR), TestCallback]

In [13]:
def get_model_opt(data, enc_channels,dec_channels,bn=True,z_dim=200,enc_layer= conv_layer,dec_layer= conv_transpose_layer,
                  opt = 'Adam', lr = 0.001,state_dict_path = None,dropout = True,**kwargs):
    """Build a ``Variational_Autoencoder`` and an optimizer over its parameters.

    Args:
        data: object with a ``train_dl``; one batch from it is pushed through
            the model before the optimizer is created.
        enc_channels, dec_channels: channel lists forwarded to the model.
        bn, z_dim, enc_layer, dec_layer, dropout, **kwargs: forwarded to the
            model constructor.
        opt: name of the optimizer class looked up on ``optim`` (e.g. ``'Adam'``).
        lr: learning rate passed to the optimizer.
        state_dict_path: optional checkpoint loaded into the model after the
            dummy forward pass.

    Returns:
        ``(model, optimizer)``.
    """
    model = Variational_Autoencoder(enc_channels,dec_channels,bn = bn,z_dim=z_dim, enc_layer=enc_layer,
                                   dec_layer=dec_layer, dropout = dropout,**kwargs)
    op = getattr(optim, opt)
    # One forward pass before building the optimizer.
    # NOTE(review): presumably this materialises input-size-dependent layers -- confirm.
    # No gradients are needed here, so don't record an autograd graph for a whole batch.
    x = next(iter(data.train_dl))[0]
    with torch.no_grad():
        t = model(x)
    del t,x
    gc.collect()
    if state_dict_path is not None:
        # The model is still on the CPU; map the checkpoint there too so tensors
        # saved from a GPU run are not first materialised on the GPU.
        model.load_state_dict(torch.load(state_dict_path, map_location='cpu'))
    return model, op(model.parameters(), lr = lr)

In [14]:
model, opt = get_model_opt(
    data,
    enc_channels=[32, 64, 64, 64],
    dec_channels=[64, 64, 64, 32],
    lr=0.0005,
)

In [None]:
# Display the model's repr (presumably a torch nn.Module, i.e. its layer tree).
model

In [15]:
def get_learn_run(cbfs,model,opt,data,loss_func = F.cross_entropy):
    """Wrap model, optimizer, data and loss in a Learner and create a Runner1.

    Args:
        cbfs: callback factories handed to ``Runner1(cb_funcs=...)``.
        model, opt, data, loss_func: passed positionally to ``Learner``.

    NOTE(review): Runner1 calls ``loss_func(model, pred, yb)``, so the
    ``F.cross_entropy`` default does not fit that call -- pass a loss that
    takes the model as its first argument.

    Returns:
        ``(learner, runner)``.
    """
    return Learner(model, opt, data, loss_func), Runner1(cb_funcs=cbfs)

In [16]:
learn, run = get_learn_run(
    cbfs=cbs, model=model, opt=opt, data=data,
    loss_func=Total_loss(blocks),
)

In [None]:
run.loss_func??

In [17]:
# Train for 3 epochs (callbacks in `cbs`, such as TestCallback, may stop it early).
run.fit(epochs=3, learn=learn)

0.0


RuntimeError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 4.00 GiB total capacity; 1.12 GiB already allocated; 0 bytes free; 1.15 GiB reserved in total by PyTorch)

In [None]:
def show_image(im, figsize=(3,3)):
    """Display a channel-first image tensor with the axes hidden.

    Args:
        im: image tensor laid out as (channels, height, width). It is detached
            and copied to the CPU first, so model outputs that live on the GPU
            or track gradients can be passed directly.
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    plt.axis('off')
    # imshow expects (height, width, channels) data on the CPU.
    plt.imshow(im.detach().cpu().permute(1,2,0))

In [None]:
# Input image at index 1, for comparison with its reconstruction.
show_image(il[1], figsize=(3, 3))

In [None]:
# This cell used to read `a`, which is only assigned in the following cell, so it
# failed on Restart & Run All. Compute the reconstruction here instead.
with torch.no_grad():
    recon = run.model(il[1][None,:].cuda())
recon[0].detach().cpu().shape

In [None]:
# Reconstruct item 1 on the GPU. No gradients are needed for inspection, and
# no_grad avoids keeping an autograd graph alive for as long as `a` exists.
with torch.no_grad():
    a = run.model(il[1][None,:].cuda())

In [None]:
# First (only) image of the reconstructed batch.
show_image(a[0].detach().cpu(), figsize=(3, 3))

In [None]:
# Pixel-wise MSE between the reconstruction and the original image.
F.mse_loss(input=a[0].detach().cpu(), target=il[1])

In [None]:
# CudaCallback moves `model` to the GPU during `run.fit`, so a CPU input would hit
# a device mismatch. Send the input to wherever the parameters currently live.
device = next(model.parameters()).device
with torch.no_grad():
    a = model(il[18][None,:].to(device))

In [None]:
# Shape of `a`, the model output for the single-image batch from the previous cell.
a.shape