In [29]:
import torch
from torch import nn as nn
from wasserstein_loss import *
from train_on_NDI import *
import torchvision
import time
import collections

In [30]:
# Shape probe: run a random batch of 32 single-channel 200x200 images through a
# torchvision ResNet-50 stage by stage and keep every intermediate activation.
inputs = torch.randn((32, 1, 200, 200))
model = torchvision.models.resnet50()
# Replace the 3-channel stem with a 1-channel one, copying every other
# hyper-parameter from the stock layer. `bias` has to be forwarded as well:
# cloning only up to `groups` silently fell back to nn.Conv2d's default
# bias=True, unlike the stock stem and the bias=False stem that
# get_self_pretrain_model builds for loading checkpoints.
model.conv1 = nn.Conv2d(1, model.conv1.out_channels,
                        kernel_size=model.conv1.kernel_size,
                        stride=model.conv1.stride,
                        padding=model.conv1.padding,
                        dilation=model.conv1.dilation,
                        groups=model.conv1.groups,
                        bias=model.conv1.bias is not None)
# Swap the classification head for a 512-wide linear layer.
model.fc = nn.Linear(model.fc.in_features, 512)
outputs = []
# Stem: conv -> batch norm -> ReLU -> max-pool.
x1 = model.maxpool(model.relu(model.bn1(model.conv1(inputs))))
outputs.append(x1)
x2 = model.layer1(x1)
outputs.append(x2)
x3 = model.layer2(x2)
outputs.append(x3)
x4 = model.layer3(x3)
outputs.append(x4)
x5 = model.layer4(x4)
outputs.append(x5)
# Pool, flatten everything after the batch dimension, then apply the new head.
x6 = model.fc(model.avgpool(x5).flatten(1))
outputs.append(x6)

In [20]:
# Report the size of every activation collected in `outputs`, one per line.
for x in iter(outputs):
    print(x.size())

torch.Size([32, 64, 50, 50])
torch.Size([32, 256, 50, 50])
torch.Size([32, 512, 25, 25])
torch.Size([32, 1024, 13, 13])
torch.Size([32, 2048, 7, 7])
torch.Size([32, 512])


In [31]:
# Display the class of the module stored as `model.avgpool`.
type(model.avgpool)

torch.nn.modules.pooling.AdaptiveAvgPool2d

In [19]:
class TransformNet(nn.Module):
    """A single square linear layer whose output rows are rescaled to unit L2 norm."""

    def __init__(self, size):
        """Create a ``size`` -> ``size`` linear map wrapped in an ``nn.Sequential``.

        Args:
            size: width of both the input and the output features.
        """
        super().__init__()
        self.size = size
        linear = nn.Linear(size, size)
        self.net = nn.Sequential(linear)

    def forward(self, input):
        """Apply the linear map, then divide each row by its Euclidean norm over dim 1.

        No epsilon is added to the norm, so an all-zero output row yields NaNs.
        """
        projected = self.net(input)
        row_norm = projected.pow(2).sum(dim=1, keepdim=True).sqrt()
        return projected / row_norm

def get_self_pretrain_model(index=1000):
    """Build a 1-channel ResNet-50 whose backbone is loaded from a local checkpoint.

    Reads ``./checkpoints/ImageNet_ALL_CHECK_{index}_Epoch.pth``, keeps the entries
    of its ``'state_dict'`` whose key contains ``'encoder_q'`` but not ``'fc'``,
    drops the first dotted component of each kept key, and loads the result into
    the backbone. A new, randomly initialised ``Linear(in_features, 512)`` head is
    attached afterwards.

    Args:
        index: value interpolated into the checkpoint file name
            (presumably the pre-training epoch -- confirm against the training script).

    Returns:
        The torchvision ResNet-50 (on CPU) with the loaded backbone and a fresh
        512-wide ``fc`` layer.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        RuntimeError: if the filtered keys do not match the backbone exactly
            (``load_state_dict`` runs in its default strict mode).
    """
    base_encoder = torchvision.models.resnet50(weights=None)
    # Single-channel stem; bias=False matches the stock ResNet stem.
    base_encoder.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    origin_dim_mlp = base_encoder.fc.in_features
    # Drop the head so the strict load below does not expect fc.* entries.
    base_encoder.fc = None
    # map_location='cpu': the freshly built model lives on the CPU, so tensors saved
    # from a GPU run must be remapped. Without it the load fails on CPU-only machines
    # and needlessly allocates GPU memory elsewhere.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    temp = torch.load(f'./checkpoints/ImageNet_ALL_CHECK_{index}_Epoch.pth',
                      map_location='cpu')['state_dict']
    state_dict = {}
    for k, v in temp.items():
        if 'encoder_q' in k:
            if 'fc' not in k:
                # Strip the leading dotted component of the checkpoint key.
                state_dict['.'.join(k.split('.')[1:])] = v
    base_encoder.load_state_dict(state_dict)
    base_encoder.fc = torch.nn.Linear(origin_dim_mlp, 512)
    return base_encoder

In [25]:
class simple_NN(nn.Module):
    """Minimal module holding one 512 -> 512 fully connected layer."""

    def __init__(self) -> None:
        super(simple_NN, self).__init__()
        self.fc = nn.Linear(in_features=512, out_features=512)

    def forward(self, x):
        """Return ``self.fc`` applied to ``x``."""
        projected = self.fc(x)
        return projected


In [27]:
# Manual walk-through of a few optimisation steps of the projection network `f`:
# it is updated to minimise `reg - wasserstein_distance` on two detached batches.
max_iter = 10  # number of optimiser steps taken on f
lam = 1        # weight of the cosine regulariser
# Base directions fed to f (helper defined outside this notebook).
# presumably 1024 directions of dimension 512 -- confirm against rand_projections
pro = rand_projections(512, 1024).to('cuda')
inputs_1, inputs_2 = torch.randn((32, 512)).cuda(), torch.randn((32, 512)).cuda()
simple_net = simple_NN()
simple_net.cuda()
first_samples = simple_net(inputs_1)
second_samples = simple_net(inputs_2)
# Detach so the loop below only produces gradients for f, not for simple_net.
first_samples_detach = first_samples.detach()
second_samples_detach = second_samples.detach()
f = TransformNet(512).to('cuda')
f_op = torch.optim.Adam(f.parameters(), lr=0.0005, betas=(0.5, 0.999))
p = 2  # forwarded as the last argument of get_wasserstein_distance_final_step

for _ in range(max_iter):
    # Re-run f on the base directions every step, so its output tracks the updated weights.
    projections = f(pro)
    cos = cosine_distance_torch(projections, projections)
    reg = lam * cos
    # Project both detached batches onto every direction produced by f.
    encoded_projections = first_samples_detach.matmul(projections.transpose(0, 1))
    distribution_projections = second_samples_detach.matmul(projections.transpose(0, 1))
    wasserstein_distance = get_wasserstein_distance_final_step(encoded_projections, distribution_projections, p)
    # Minimising this maximises the distance while penalising the regulariser.
    loss = reg - wasserstein_distance
    f_op.zero_grad()
    # NOTE(review): everything that depends on f is rebuilt each iteration, so
    # retain_graph=True looks unnecessary unless `pro` carries its own graph -- confirm.
    loss.backward(retain_graph=True)
    f_op.step()

In [14]:
# The hand-written loss (loss_1) and the nn.CrossEntropyLoss version (loss_2) should
# agree. They are computed along different code paths, so compare with a floating
# point tolerance instead of exact equality, which can flip on the last bit.
torch.isclose(loss_1, loss_2)

tensor(True)

In [3]:
from torch.nn import functional as F

# Two independent random batches: 32 vectors of dimension 512 each.
inputs1 = torch.randn(32, 512)
inputs2 = torch.randn(32, 512)
# Normalise every row along dim 1, so each dot product below is a cosine similarity.
em_q, em_k = (F.normalize(batch, dim=1) for batch in (inputs1, inputs2))
# sim_matrix[i, j] is the dot product of em_q[i] with em_k[j].
sim_matrix = em_q @ em_k.T




In [10]:
# Hand-written symmetric loss: for each direction take the row-wise log-softmax of
# the similarity matrix, keep its diagonal, and average the negated entries.
logpt1 = F.log_softmax(sim_matrix, dim=-1).diag()
loss1 = -logpt1.mean()
# Same computation on the transposed matrix.
logpt2 = F.log_softmax(sim_matrix.T, dim=-1).diag()
loss2 = -logpt2.mean()
loss_1 = loss1 + loss2
print(loss1, loss2, loss_1, sep='\n')

tensor(3.4713)
tensor(3.4712)
tensor(6.9425)


In [11]:
# Same symmetric loss expressed with nn.CrossEntropyLoss: row i's target class is i.
criterion = nn.CrossEntropyLoss()
# Build the targets from sim_matrix itself (row count and device) rather than from
# `inputs1`: that name is rebound in other cells (including to CUDA tensors), and a
# default-device arange would then mismatch the logits.
diag_targets = torch.arange(sim_matrix.size(0), device=sim_matrix.device)
loss1 = criterion(sim_matrix, diag_targets)
loss2 = criterion(sim_matrix.t(), diag_targets)
loss_2 = loss1 + loss2
print(loss1)
print(loss2)
print(loss_2)

tensor(3.4713)
tensor(3.4712)
tensor(6.9425)


In [None]:
# Rebind `criterion` to the functional cross-entropy, replacing any earlier binding of the name.
criterion = F.cross_entropy

In [36]:
# Two random batches (32 x 512), created on the CPU and then moved to the GPU.
inputs1 = torch.randn(32, 512).to('cuda')
inputs2 = torch.randn(32, 512).to('cuda')
# Projection network and its Adam optimiser.
transform_net = TransformNet(512)
op_transnet = torch.optim.Adam(params=transform_net.parameters(), lr=0.0005, betas=(0.5, 0.999))
# Base directions from the external helper; not moved to the GPU here.
temp1 = rand_projections(512, 1000)
# Move the network to the GPU; as the last expression it is also what the cell displays.
transform_net.to('cuda')


TransformNet(
  (net): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [37]:
# Call the external helper on the two batches with the transform net and its optimiser.
# NOTE(review): positional args look like (first, second, num_projections=1024, f, f_op,
# p=2, max_iter=10, lam=1, device='cuda') -- confirm against wasserstein_loss.
dswd = distributional_sliced_wasserstein_distance(inputs1.cuda(), inputs2.cuda(), 1024, transform_net, op_transnet, 2, 10, 1, 'cuda')

In [38]:
# Display the value currently bound to `dswd`.
dswd

tensor(7.5800, device='cuda:0', grad_fn=<PowBackward0>)

In [27]:
# Evaluate cosine_distance_torch on the untransformed directions `temp1` against themselves.
cosine_distance_torch(temp1, temp1)

tensor(0.0362)

In [28]:
# Rebind `transform_net` to a freshly initialised TransformNet; it is not moved to the GPU here.
transform_net = TransformNet(512)

In [32]:
# Push the base directions through the transform net.
projections = transform_net(temp1)
# Regulariser with weight 1 (the multiplication is kept so the expression matches the loop form).
reg = 1 * cosine_distance_torch(projections, projections)
# Multiply each batch by the transposed direction matrix.
encoded_projections = torch.matmul(inputs1, projections.t())
distribution_projections = torch.matmul(inputs2, projections.t())
wasserstein_distance = get_wasserstein_distance_final_step(encoded_projections, distribution_projections, 2)


tensor(0.5076, grad_fn=<MulBackward0>)

In [26]:
# Display the value currently bound to `reg`.
reg

tensor(0.0662)

In [None]:
def train_moco_return_metrics_top_k(net, train_iter, val_iter, optimizer, epochs, device, tested_parameter, criterion=None,
                                    k_candidates=(10,), scheduler=None):
    """Train ``net`` for ``epochs`` epochs and record loss and top-k accuracy per epoch.

    Each training batch is optimised with ``train_batch`` and then scored in eval
    mode by matching ``origin`` against the tensor built from ``TARGET_IMAGE``.
    After every epoch the validation set is scored the same way.

    Args:
        net: model called as ``net(origin, target)``; must also provide
            ``get_similarity_matrix`` and ``compute_loss``.
        train_iter: iterable of ``(origin, target, label)`` training batches.
        val_iter: iterable of ``(origin, target, label)`` validation batches.
        optimizer: optimiser stepped once per training batch.
        epochs: number of passes over ``train_iter``.
        device: CUDA device the model and batches are moved to.
        tested_parameter: iterable identifying this run; ``tuple(tested_parameter)``
            becomes the key of the returned mapping.
        criterion: optional callable ``criterion(em_q, em_k)``. When given it is
            used for both the training and the validation loss; otherwise
            ``net.compute_loss`` is applied to the similarity matrix.
        k_candidates: the k values passed to ``cal_accuracy_top_k``.
        scheduler: optional LR scheduler, stepped once per epoch.

    Returns:
        ``normalize_data_format({tuple(tested_parameter): (train_loss, train_acc,
        val_loss, val_acc)})``: the losses are arrays with one entry per epoch and
        the accuracies map each k to such an array.
    """
    target_tensor = get_CNI_tensor(TARGET_IMAGE, device=device)
    train_loss_record = []
    train_acc_record = {k: [] for k in k_candidates}
    val_loss_record = []
    val_acc_record = {k: [] for k in k_candidates}
    for epoch in range(epochs):
        net.cuda(device)
        total_loss = 0
        training_correct = collections.defaultdict(int)
        training_size = 0
        for origin, target, label in train_iter:
            net.train()
            # Forward `criterion`: it used to be accepted by this function but never used.
            total_loss += train_batch(net, origin, target, label, optimizer,
                                      criterion=criterion, device=device)
            # Score the same batch in eval mode, without tracking gradients.
            net.eval()
            with torch.no_grad():
                # presumably one correct-count per k, in k_candidates order -- confirm
                batch_hits = cal_accuracy_top_k(
                    image_pair_matching(net, origin.to(device), target_tensor),
                    label.to(device), top_k=k_candidates)
                for k, correct in zip(k_candidates, batch_hits):
                    training_correct[k] += correct
                training_size += origin.shape[0]
        if scheduler is not None:
            scheduler.step()
        net.eval()
        with torch.no_grad():
            val_loss = 0
            val_correct = collections.defaultdict(int)
            val_size = 0
            for origin, target, label in val_iter:
                origin, target, label = origin.cuda(
                    device), target.cuda(device), label.cuda(device)
                em_q, em_k = net(origin, target)
                # Mirror train_batch so the train and validation losses are comparable.
                if criterion is not None:
                    val_loss += criterion(em_q, em_k).item()
                else:
                    sim_matrix = net.get_similarity_matrix(em_q, em_k)
                    val_loss += net.compute_loss(sim_matrix).item()
                batch_hits = cal_accuracy_top_k(
                    image_pair_matching(net, origin, target_tensor), label,
                    top_k=k_candidates)
                for k, correct in zip(k_candidates, batch_hits):
                    val_correct[k] += correct
                val_size += origin.shape[0]
        # Divide by the number of validation samples seen in total. Using the last
        # batch's size is only right when the loader yields exactly one batch.
        val_acc = {k: correct / val_size
                   for k, correct in val_correct.items()}
        train_acc = {k: correct / training_size
                     for k, correct in training_correct.items()}
        train_loss_record.append(total_loss / len(train_iter))
        for k, v in train_acc.items():
            train_acc_record[k].append(v)
        val_loss_record.append(val_loss / len(val_iter))
        for k, v in val_acc.items():
            val_acc_record[k].append(v)
        print(
            f'Epoch {epoch + 1}, Train_Loss {total_loss / len(train_iter)}, Val_loss {val_loss / len(val_iter)}')
        for k, acc in val_acc.items():
            print(f'Val_acc_top_{k} {round(acc, 2)}', end='\t')
        print()
    output = normalize_data_format(
        {tuple(tested_parameter): (train_loss_record, train_acc_record, val_loss_record, val_acc_record)})
    return output


def normalize_data_format(data: dict, inner=False):
    """Convert the lists inside a (possibly nested) metrics mapping to numpy arrays.

    For every key of ``data``:
      * a tuple value contributes one entry per element: lists become arrays, and
        anything else is normalised recursively with ``inner=True``;
      * a list value contributes a single array;
      * a dict value contributes a single recursively normalised mapping;
      * values of any other type are left out of the result.

    Args:
        data: mapping to normalise.
        inner: when true, a key that collected exactly one entry is stored as that
            entry itself instead of a one-element list.

    Returns:
        A ``collections.defaultdict(list)`` holding the converted entries.
    """
    def _convert(entry):
        # Lists become arrays; everything else is treated as a nested mapping.
        if isinstance(entry, list):
            return np.array(entry)
        return normalize_data_format(entry, inner=True)

    normalized = collections.defaultdict(list)
    for key, value in data.items():
        if isinstance(value, tuple):
            members = value
        elif isinstance(value, (list, dict)):
            members = (value,)
        else:
            continue
        for member in members:
            normalized[key].append(_convert(member))

    if not inner:
        return normalized

    # Unwrap the keys that ended up with a single entry.
    for key in list(normalized):
        bucket = normalized[key]
        if len(bucket) == 1:
            normalized[key] = bucket[0]
    return normalized


def train_batch(net, observed, calculated, label, optimizer, criterion=None, device=None):
    """Run one optimisation step on a single batch and return the loss as a float.

    Args:
        net: model called as ``net(observed, calculated)`` and returning two
            embeddings; must provide ``get_similarity_matrix`` and ``compute_loss``
            when ``criterion`` is ``None``.
        observed: first input batch.
        calculated: second input batch.
        label: batch labels; moved to ``device`` with the inputs but not otherwise
            used here.
        optimizer: optimiser whose gradients are reset and which is stepped once.
        criterion: optional callable ``criterion(em_q, em_k)`` returning the loss.
        device: optional target device for the three tensors; ``None`` leaves them
            where they are.

    Returns:
        The batch loss as a Python float.
    """
    # Compare against None explicitly: device index 0 and a criterion object that
    # defines __len__/__bool__ are valid but falsy.
    if device is not None:
        observed, calculated, label = observed.to(
            device), calculated.to(device), label.to(device)
    em_q, em_k = net(observed, calculated)
    if criterion is not None:
        loss = criterion(em_q, em_k)
    else:
        sim_matrix = net.get_similarity_matrix(em_q, em_k)
        loss = net.compute_loss(sim_matrix)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


In [None]:
# k-fold driver: for every parameter combination, train a fresh model on each fold
# and accumulate the returned metrics in `train_metrics`.
top_k_candidates = (10, 20, 30)  # k values forwarded as k_candidates
k = 7  # passed to k_fold_train_validation_split and to cal_divide below (presumably the fold count)
# NOTE(review): temps, momentums and k_value are not referenced anywhere in this cell.
temps = 0.7
momentums = 0.99
k_value = 64
parameters = {'epochs_pretrain_model': [400]}
# parameters = {'pretrain_model': ['self_pretrained', 'CEM', 'ImageNet', 'None']}
# parameters = {'pretrain_model': ['CEM']}
train_metrics = HistoryRecorder(['Train Loss', 'Train Acc', 'Val Loss', 'Val Acc'], list(parameters.keys()))
# From here on `parameters` is the list of value tuples (Cartesian product), no longer the dict.
parameters = list(itertools.product(*parameters.values()))

for i, parameter in enumerate(parameters):

    ### custom part to get parameters
    # The only swept value; used below as the checkpoint index.
    pretrain_model = parameter[0]
    ### END
    
    for j, images in enumerate(k_fold_train_validation_split(ORIGINAL_IMAGE, TARGET_IMAGE, k)):
        # Both datasets are built from the same `images`; only the boolean flag differs.
        train_dataset = SingleChannelNDIDatasetContrastiveLearningWithAug(images, False)
        val_dataset = SingleChannelNDIDatasetContrastiveLearningWithAug(images, True)
        train_iter = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
        # The whole validation set is served as a single batch.
        val_iter = DataLoader(val_dataset, batch_size=len(val_dataset))

        # Fresh model per fold: pretrained backbone wrapped in RetrievalModel, moved to the GPU.
        model = get_self_pretrain_model(index=pretrain_model)
        model = RetrievalModel(model)
        model = model.cuda()
        
        
        device = torch.device('cuda:0')
        optimizer = torch.optim.SGD(model.parameters(), lr=5e-3, momentum=0.9, weight_decay=1e-4)
        # NOTE(review): T_max=30 duplicates the epoch count passed to the training call below -- keep in sync.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30, eta_min=5e-5, verbose=1)
        # NOTE(review): start_time / end_time are recorded but never read in this cell.
        start_time = time.time()
        print(f'Parameter Index: {i} / {len(parameters)}, Fold Index: {j} / {k}')
        metrics = train_moco_return_metrics_top_k(model, train_iter, val_iter, optimizer, 30, device,
                                                    tested_parameter=parameter, k_candidates=top_k_candidates, scheduler=scheduler)
        end_time = time.time()
        train_metrics.cal_add(metrics)
# Called once after all folds with k (presumably to average the accumulated metrics -- confirm).
train_metrics.cal_divide(k)