In [65]:
from imports import *

In [85]:
class contrastive(nn.Module) :
    def __init__(self, input_size, out_size, t=2):
        super().__init__()
        self.weight = nn.Linear(input_size, out_size)
        self.loss = nn.CrossEntropyLoss()
        self.t = t
        
    def forward(self, features1, features2) :
        g1, g2 = F.normalize(self.weight(features1)), F.normalize(self.weight(features2))
        dists = (g1 @ g2.T)/self.t
        dists = torch.eye(64, device='cuda')
        print(dists)
        target = torch.arange(dists.shape[0]).to('cuda')
        return self.loss(dists, target)



In [86]:
def train_contrastive(model, optimizer, loss_func, sched, metric_fc, train_dl, val_dl, n_epochs, train_df, val_df,
          train_transforms, val_transforms, save_path, val_first=False, 
          prev_best_info={'val': {'thr': None, 'f1': None}, 'train': {'thr': None, 'f1': None}},
          info_history=[], ep_start=0, half_precision=False):
    if half_precision :
        return train_16(model, optimizer, loss_func, sched, metric_fc, train_dl, val_dl, n_epochs, train_df, val_df,
          train_transforms, val_transforms, save_path, val_first, 
          prev_best_info, info_history, ep_start)
    tr_losses = []
    tr_scores = []
    val_scores = []
    prev_best_f_score = -10

    for ep in tqdm(range(ep_start, ep_start + n_epochs), leave=False):
        # TRAINING
        model.train()
        tr_loss = []
        embs = []
        ys = []
        pbar = tqdm(train_dl, leave=False)
        for imgs, labels in pbar:
            ys.append(labels)
            imgs1 = train_transforms(imgs.to('cuda'))
            imgs2 = train_transforms(imgs.to('cuda'))
            optimizer.zero_grad()
            feature1 = model(imgs1)
            feature2 = model(imgs2)
            
            loss = metric_fc(feature1, feature2)
            
            loss.backward()
            optimizer.step()
            if sched is not None : sched.step()

            tr_loss.append(loss.item())
            pbar.set_description(f"Train loss: {round(np.mean(tr_loss),3)}")
            embs.append(feature1.detach())
        ys = pd.Series(torch.cat(ys, 0).numpy())
        embs = F.normalize(torch.cat(embs, 0))
        
        # compute fsccores
        if prev_best_info['train']['thr'] is None :
            thrs = np.linspace(0.2, 1, 10)
        else :
            thrs = [prev_best_info['train']['thr'] - 0.1, prev_best_info['train']['thr'] - 0.05, prev_best_info['train']['thr'], prev_best_info['train']['thr'] + 0.05, prev_best_info['train']['thr'] + 0.1]
        train_f1s, best_thresh_tr, f1_tr = compute_f1(embs, ys, thrs)
        prev_best_info['train']['thr'], prev_best_info['train']['f1'] = best_thresh_tr, f1_tr

        if ep % 2 == 0:
            path =  save_path + '_ep_{}.pth'.format(ep)
            print('Checkpoint : saved model to {}'.format(path))
            torch.save(model.state_dict(
            ),path)

        # VALIDATION
        model.eval()
        val_loss = 0
        with torch.no_grad():
            pbar = tqdm(val_dl, leave=False)
            embs = []
            for imgs, _ in pbar:
                imgs1 = val_transforms(imgs).to('cuda')
                imgs2 = val_transforms(imgs).to('cuda')
                feature1 = model(imgs1)
                feature2 = model(imgs2)
                
                loss = metric_fc(feature1, feature2)
                val_loss += loss.item()/len(val_dl)
                embs.append(feature1)
            embs = F.normalize(torch.cat(embs, 0))

            # compute fsccores
            if prev_best_info['val']['thr'] is None :
                thrs = np.linspace(0.2, 1, 10)
            else :
                thrs = [prev_best_info['val']['thr'] - 0.1, prev_best_info['val']['thr'] - 0.05, prev_best_info['val']['thr'], prev_best_info['val']['thr'] + 0.05, prev_best_info['val']['thr'] + 0.1]
            val_f1s, best_thresh_val, f1_val = compute_f1(embs, val_df['label_group'], thrs)
            prev_best_info['val']['thr'], prev_best_info['val']['f1'] = best_thresh_val, f1_val

            if f1_val > prev_best_f_score:
                prev_best_f_score = f1_val
                torch.save(model.state_dict(
                ), save_path + '_best.pth'.format(ep))
                print('Saved best model ep {} with f score : {}'.format(
                    ep, f1_val))
        info_history.append(copy.deepcopy(prev_best_info))

        tr_losses.append(tr_loss)
        val_scores.append(f1_val)
        summary = "Ep {}: Train loss {:.4f} | Val f score {:.4f} with thresh {:.2f}, train f score {:.4f} with thresh {:.2f}".format(
            ep, np.asarray(tr_loss).mean(), f1_val, best_thresh_val, f1_tr, best_thresh_tr)
        print(summary)
    return prev_best_info, info_history, (tr_losses, val_scores)

In [87]:
from imports import *
from utils import load_data
from image_train.data import create_dl, ImageDS
from image_train.model import EMBRes
from arcface import ArcMarginProduct, compute_centers
from image_train.train import *
import matplotlib.pyplot as plt
np.random.seed(1337)
device = torch.device('cuda')

In [88]:
df, train_df, val_df, train_labels, val_labels = load_data(train_perc=0.3)

In [89]:
# creating dataloaders
small_images_dir_train = 'data/small_train_images_300/'
small_images_dir_val = 'data/small_train_images_224/'
bs = 64

tr_dl = create_dl(train_df, small_images_dir_train, batch_size=bs)
tr_test_dl = create_dl(train_df, small_images_dir_val, shuffle=False, batch_size=bs)
val_dl = create_dl(val_df, small_images_dir_val, shuffle=False, batch_size=bs)
#full_dl = create_dl(df, small_images_dir, shuffle=False)

In [90]:
vision_model = 'resnet18'
model = timm.create_model(vision_model, pretrained=True, num_classes=0).to('cuda')
train_tfms, val_tfms = get_tfms(crop=224)

In [91]:
metric_fc = contrastive(512, 512, t=2).to(device)

In [92]:
n_epochs, lf, params, optimizer, sched = get_hparams(tr_dl, model, metric_fc, lr=5e-4, n_epochs=15)

In [93]:
loss_hist = []
best_thr_score={'val': {'thr': None, 'f1': None}, 'train': {'thr': None, 'f1': None}}
thr_score_hist=[]
ep_start = 0
save_path = 'data/test'

In [94]:
best_thr_score, thr_score_hist, losses = train_contrastive(model, optimizer, lf, sched, metric_fc, tr_dl, val_dl,
                                               n_epochs, train_df, val_df, train_tfms, val_tfms, 
                                               save_path=save_path, 
                                               prev_best_info=best_thr_score, info_history=thr_score_hist,
                                               ep_start=ep_start)
loss_hist.append(losses)
ep_start += n_epochs

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=164.0), HTML(value='')))

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]], device='cuda:0')


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [13]:
torch.save(metric_fc, 'data/image_models/arcmarg_14ep_0.3.pth')

In [10]:
def plot_hist(history) : 
    train_scores = [info['train']['f1'] for info in history]
    train_thr = [info['train']['thr'] for info in history]
    val_scores = [info['val']['f1'] for info in history]
    val_thr = [info['val']['thr'] for info in history]
    x = range(len(train_scores))
    plt.plot(x, val_scores, label='val_score')
    plt.plot(x, train_scores, label='train_score')
    plt.legend()
    plt.show()
    plt.plot(x, val_thr, label='val_thr')
    plt.plot(x, train_thr, label='train_thr')
    plt.legend()
    plt.show()