In [1]:
import numpy as np
import torch
from glob import glob
import random
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.utils import shuffle
from builtins import getattr
from genericpath import isdir
from templates import *
from templates_cls import *
from experiment_classifier import ClsModel
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [2]:
def normalize(cond, mean, std):
    return (cond - mean.to(cond.device)) / std.to(cond.device)

def denormalize(cond, mean, std):
    return (cond * std.to(cond.device)) + mean.to(cond.device)

In [3]:
def cos(a, b):
    a = a.view(a.shape[0], -1)
    b = b.view(b.shape[0], -1)
    a = F.normalize(a, dim=1)
    b = F.normalize(b, dim=1)
    return (a * b).sum(dim=1)

def spherical_interpolation(x0, x1, alpha):
    theta = th.arccos(cos(x0, x1))
    # fix the divid by zero problem with identical ends
    coef = ((th.sin(
        (1 - alpha) * theta) + 1e-8) / (th.sin(theta) + 1e-8))[:, None,
                                                                  None, None]
    a = coef * x0
    b = (th.sin(alpha * theta) /
         (th.sin(theta) + 1e-8))[:, None, None, None] * x1
    return a + b


def sqrt_interpolation(x0, x1, alpha):
    # doesn't work well with identical ends
    return ((1 - alpha) * x0 + (alpha) * x1) / math.sqrt(alpha**2 +
                                                         (1 - alpha)**2)


def linear_interpolation(x0, x1, alpha):
    return ((1 - alpha) * x0 + (alpha) * x1)

In [4]:
device = 'cuda:3'
lat = torch.load(f'checkpoints/ffhq256_autoenc/latent_train.pkl', map_location='cpu')
data_conds = lat['conds']
# data_conds = normalize(lat['conds'], lat['conds_mean'], lat['conds_std'])
print(lat['conds'].shape)

torch.Size([60000, 512])


In [5]:
label_path = '/home/nontawat/shadow_labels/*.txt'
count = 0

shadow_img = []
normal_img = []


for tx in glob(label_path):
    f = open(tx, 'r')
    line = f.readlines()

    for i in range(len(line)):
        cls = line[i].split()[-1]
        idx = int(line[i].split()[0].split('.')[0])
        if cls == '0':
            normal_img.append(data_conds[idx])
        elif cls == '1':
            # print(idx)
            shadow_img.append(data_conds[idx])
            count+=1
            
print(f'TOTAL SHADOW {count}')


TOTAL SHADOW 612


In [6]:
# hard_neg = '/home/nontawat/grad_mag.text'
# file = open(hard_neg, 'r')
# neg_line = file.readlines()
# def alg(line):
#     de = []
#     for i in range(len(line)):
#         de.append(tuple(line[i].split('_')))

#     sde = sorted(de, key=lambda x: x[1], reverse=False)
#     return sde

# sorted_hard_neg = alg(neg_line)
# norm_neg = []
# for i in range(len(sorted_hard_neg)):
#     idx = int(sorted_hard_neg[i][0])
#     norm_neg.append(data_conds[idx])

In [7]:
random.shuffle(normal_img)
sub_norm = normal_img[:len(shadow_img)]
# sub_norm = norm_neg[:len(shadow_img)]
# len(shadow_img)
train_norm = np.stack(sub_norm, axis=0)
train_shadow = np.stack(shadow_img, axis=0)
norm_label = np.zeros(train_norm.shape[0])
shadow_label = np.ones(train_shadow.shape[0])

In [8]:
print(train_norm.shape)
print(train_shadow.shape)
# mnorm = torch.from_numpy(np.mean(train_norm[:2], axis=0)).to(device)
# mshad = torch.from_numpy(np.mean(train_shadow[:2],axis=0)).to(device)
# direc = mnorm - mshad

(612, 512)
(612, 512)


In [9]:
# print(direc)

In [10]:
img = np.concatenate([train_norm, train_shadow], axis=0)
label = np.concatenate([norm_label, shadow_label], axis=0)
# X, Y = shuffle(img, label, random_state=0)

x_train, x_test, y_train, y_test = train_test_split(img, label, train_size=0.95)
print(x_train.shape)
print(y_train.shape)
print(x_test.shape)
print(y_test.shape)

(1162, 512)
(1162,)
(62, 512)
(62,)


In [11]:
reg = LogisticRegression(tol=1e-2, max_iter=500000, verbose=0)
# param = {'penalty': ['l1', 'l2', 'elasticnet', 'none'],
#     'tol': [1e-2, 1e-4, 1e-6, 1e-8],
#     'C':[1e-8, 1e-6, 1e-4, 1e-2, 1, 1.2, 1.4, 2, 3, 4, 10, 100]}
# st = StratifiedKFold(n_splits=10)
# clf = GridSearchCV(reg, param, cv=st)
# clf.fit(img, label)
reg.fit(img, label)

LogisticRegression(max_iter=500000, tol=0.01)

In [12]:
reg.score(img, label)

0.9223856209150327

In [13]:
print(reg.coef_.shape)
print(reg.intercept_)
cls_dict = {}
cls_dict['weight'] = reg.coef_
cls_dict['bias'] = reg.intercept_

(1, 512)
[-1.28457544]


In [14]:
import pickle
with open('cls_weight.pkl', 'wb') as ob:
    pickle.dump(cls_dict, ob, protocol=pickle.HIGHEST_PROTOCOL)

In [24]:
print(data_conds.shape)

torch.Size([60000, 512])


In [1]:
import pickle, torch
pkl_file = open('cls_weight.pkl', 'rb')
cls_weight = pickle.load(pkl_file)

In [2]:
lat = torch.load(f'checkpoints/ffhq256_autoenc/latent_val.pkl', map_location='cpu')
data_conds = lat['conds']
print(cls_weight.keys())

dict_keys(['weight', 'bias'])


In [3]:
print(data_conds.shape)

torch.Size([10000, 512])


In [4]:
ndata = data_conds.numpy()
proj_data = ndata@(cls_weight['weight'].T)
proj_data_bias = proj_data+cls_weight['bias']

In [5]:
with open('val-shadow.txt', 'w') as f:
    for i in range(data_conds.shape[0]):
        f.write(f'{60000+i}.jpg {proj_data[i].item()}')
        f.write('\n')

In [6]:
with open('val-shadowbias.txt', 'w') as f:
    for i in range(data_conds.shape[0]):
        f.write(f'{60000+i}.jpg {proj_data_bias[i].item()}')
        f.write('\n')

In [None]:
reg.score(x_train, y_train)

In [None]:
print(y_test.sum())
reg.score(x_test, y_test)

In [None]:
## LOAD MAIN MODEL
conf = ffhq256_autoenc()
model = LitModel(conf)
state = torch.load(f'/home2/nontawat/diffae_logs/{conf.name}/last.ckpt', map_location='cpu')
model.load_state_dict(state['state_dict'], strict=False)
model.ema_model.eval()
model.ema_model.to(device)

In [None]:
cls_conf = ffhq256_autoenc_cls()
state = torch.load(f'/home2/nontawat/diffae_logs/{cls_conf.name}/last.ckpt',
                    map_location='cpu')
print('latent step:', state['global_step'])

In [None]:
print(state['state_dict'].keys())

In [None]:
# cls_id = CelebAttrDataset.cls_to_id['Smiling']
# mean = state['state_dict']['conds_mean'].to(device)
# std = state['state_dict']['conds_std'].to(device)
# weit = state['state_dict']['classifier.weight'][cls_id][None, :].to(device)
# print(mean.shape)
# print(weit.shape)

In [None]:
lat_val = torch.load(f'checkpoints/ffhq256_autoenc/latent_val.pkl', map_location='cpu')
img = Image.open('/home2/nontawat/ffhq_256/valid/60065.jpg')
tran = torchvision.transforms.Compose([
                            torchvision.transforms.ToTensor(),
                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                            ])
imt = tran(img)
cond = lat['conds'][65][None, :].to(device)
# xT = model.encode_stochastic(imt[None, :].to(device), cond, T=250)
# ncond = normalize(cond, mean, std)
# ncond = (cond + 0.05 * (torch.from_numpy(reg.coef_).to(device))).float()
# ncond = denormalize(ncond, mean, std)
print(cond.shape)
# print(xT.shape)

In [None]:
reg.predict_proba(cond.cpu().numpy())

In [None]:
# lat_val = torch.load(f'checkpoints/ffhq256_autoenc/latent_val.pkl', map_location='cpu')
# img = Image.open('/home2/nontawat/ffhq_256/valid/60824.jpg')
# tran = torchvision.transforms.Compose([
#                             torchvision.transforms.ToTensor(),
#                             torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#                             ])
# imt = tran(img)
# print(imt.shape)
# cond = lat_val['conds'][824][None, :].to(device)
# xTT = model.encode_stochastic(imt[None, :].to(device), cond, T=250)
# xT = torch.randn([1, 3, 256, 256]).to(device)
# # ncond = normalize(cond, lat['conds_mean'], lat['conds_std'])
# ncond = cond + direc[None, :]
# # ncond = denormalize(ncond, lat['conds_mean'], lat['conds_std'])
# print(cond.shape)
# print(xT.shape)
# xT = torch.randn([1, 3, 256, 256]).to(device)
# cat = torch.stack([data_conds[2], data_conds[4], data_conds[12], data_conds[13]], dim=0)
# mcat = cat.mean(dim=0).to(device)
# print(mcat.shape)
# ii = model.render(xT, mcat[None, :], T=250)

In [None]:
# im = ii[0].permute(1,2,0).cpu().numpy()
# plt.imshow((im*255).astype(np.uint8))

In [None]:
# cond1 = (data_conds[12]).to(device)
# cond2 = (data_conds[13]).to(device)
# xT = torch.randn([1, 3, 256, 256]).to(device)
# xT2 = torch.randn([1, 3, 256, 256]).to(device)
# with torch.no_grad():
#     imgOri1 = model.render(xT, cond1[None, :], T=250)
#     imgOri2 = model.render(xT2, cond2[None, :], T=250)
# mnorm = denormalize(mnorm, lat['conds_mean'], lat['conds_std'])
# mshad = denormalize(mshad, lat['conds_mean'], lat['conds_std'])

In [None]:
# img_list = []
# alpha = torch.linspace(0, 1, 7)
# for i in range(len(alpha)):
#     Xintp = spherical_interpolation(xT, xT2, alpha[i])
#     cintp = linear_interpolation(cond1[None, :], cond2[None, :], alpha[i])
#     imginp = model.render(Xintp, cintp, T=250)
#     img_list.append(imginp)

In [None]:
# print(imgOri1.shape)

In [None]:
# imgt1 = imgOri1[0].permute(1,2,0).cpu().numpy()
# imgt2 = imgOri2[0].permute(1,2,0).cpu().numpy()
# imgtt = np.concatenate([img_list[i][0].permute(1,2,0).cpu().numpy() for i in range(len(img_list))], axis=1)
# imgh = np.concatenate([imgt1, imgtt, imgt2], axis = 1)
# plt.imshow((imgh*255).astype(np.uint8))

In [None]:
# imgOri = model.render(xT, cond, T=250)
# imO = imgOri[0].permute(1, 2, 0).cpu().numpy()

In [None]:
# print(reg.coef_.shape)
# weit = F.normalize(torch.from_numpy(reg.coef_).to(device), dim=1)

# condS = (cond - 2.5*weit).float()

In [None]:
# plt.subplot(141), plt.imshow((imO*255).astype(np.uint8))

# imgMod = model.render(xT, ncond, T=250)
# imgNom = model.render(xTT, mnorm[None, :], T=250)
# imgShad = model.render(xTT, mshad[None, :], T=250)
# imMd = imgMod[0].permute(1, 2, 0).cpu().numpy()
# imNm = imgNom[0].permute(1, 2, 0).cpu().numpy()
# imSd = imgShad[0].permute(1, 2, 0).cpu().numpy()
# plt.subplot(142), plt.imshow((imMd*255).astype(np.uint8))
# plt.subplot(143), plt.imshow((imNm*255).astype(np.uint8))
# plt.subplot(144), plt.imshow((imSd*255).astype(np.uint8))


In [None]:
# print(train_shadow.shape)
# xH = torch.randn([612, 3, 256, 256]).to(device)

# img = model.render(xH[:8], torch.from_numpy(train_shadow[:8]).to(device), T=250)


In [None]:
# print(img.shape)
# imgt = img.permute(0, 2, 3, 1).cpu().numpy()
# imgh = np.concatenate([imgt[i]  for i in range(8)], axis = 1)
# plt.imshow((imgh*255).astype(np.uint8))