In [1]:
import visdom
from datasets import get_dataset, HyperX
import utils
import numpy as np
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import LambdaLR
import torch.optim as optim
from torch.nn import init
import torch.utils.data as data
#from torchsummary import summary
#from torch.utils.tensorboard  import SummaryWriter
from tensorboardX import SummaryWriter

import math
import os
import datetime
#import joblib
from tqdm import tqdm
import argparse

import models

# Create a Visdom client handle; the cell output "Setting up a new session..." follows this call.
# NOTE(review): `vis` is not referenced by any later cell visible here — confirm it is still needed.
vis = visdom.Visdom()

Setting up a new session...


In [101]:
# Load the 'Salinas' scene together with its ground truth and metadata.
# NOTE(review): the target folder is an absolute, machine-specific path.
(img, gt, label_values,
 ignored_labels, rgb_bands, palette) = get_dataset(
    'Salinas',
    target_folder='/home/oscar/Desktop/Exjobb/Data/',
)

# Quantities derived from the loaded data.
N_CLASSES = len(label_values)
N_BANDS = img.shape[-1]          # size of the last axis of the image array
IGNORED_LABELS = ignored_labels

# Experiment settings.
PATCH_SIZE = 5
SAMPLING_MODE = 'disjoint'
SAMPLING_PERCENTAGE = 0.05
BATCH_SIZE = 10                  # labeled batch size
UNLABELED_RATIO = 7              # unlabeled batch is BATCH_SIZE * UNLABELED_RATIO
EPOCHS = 10

# Keyword arguments forwarded to every HyperX dataset built below.
hyperparams = dict(
    patch_size=PATCH_SIZE,
    dataset='Salinas',
    ignored_labels=IGNORED_LABELS,
    flip_augmentation=True,
    radiation_augmentation=False,
    mixture_augmentation=False,
    center_pixel=True,
    supervision='full',
)

if palette is None:
    # The dataset shipped no palette: label 0 is black, every other label
    # gets one colour from seaborn's "hls" palette scaled to uint8 RGB.
    hls_colors = sns.color_palette("hls", len(label_values) - 1)
    palette = {0: (0, 0, 0)}
    palette.update(
        (label, tuple(np.asarray(255 * np.array(rgb), dtype='uint8')))
        for label, rgb in enumerate(hls_colors, start=1)
    )
# Reverse lookup: colour tuple -> label index.
invert_palette = {rgb: label for label, rgb in palette.items()}

def convert_to_color(x):
    """Call ``utils.convert_to_color_`` on ``x`` with the notebook-level ``palette``."""
    colored = utils.convert_to_color_(x, palette=palette)
    return colored
def convert_from_color(x):
    """Call ``utils.convert_from_color_`` on ``x`` with the notebook-level ``invert_palette``."""
    labels = utils.convert_from_color_(x, palette=invert_palette)
    return labels

# Split the full ground truth into a training mask and a test mask.
# presumably the second argument is the share kept in the first returned mask — TODO confirm against utils.sample_gt
train_gt, test_gt = utils.sample_gt(gt, SAMPLING_PERCENTAGE, mode=SAMPLING_MODE)
n_selected = np.count_nonzero(train_gt)
n_available = np.count_nonzero(gt)
print("{} samples selected (over {})".format(n_selected, n_available))

model = models.HamidaEtAl(N_BANDS, N_CLASSES, patch_size=PATCH_SIZE)

# Carve a validation mask out of the training mask (train_gt is overwritten).
train_gt, val_gt = utils.sample_gt(train_gt, 0.95, mode=SAMPLING_MODE)

val_dataset = HyperX(img, val_gt, labeled=True, **hyperparams)
val_loader = data.DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Split the remaining training mask into a labeled and an unlabeled part.
labeled_share = 1/(UNLABELED_RATIO + 1)
train_labeled_gt, train_unlabeled_gt = utils.sample_gt(
    train_gt, labeled_share, mode=SAMPLING_MODE
)
amount_labeled = np.count_nonzero(train_labeled_gt)

# Labeled loader: BATCH_SIZE samples per batch, incomplete last batch dropped.
train_labeled_dataset = HyperX(img, train_labeled_gt, labeled=True, **hyperparams)
train_labeled_loader = data.DataLoader(
    train_labeled_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

# Unlabeled loader: UNLABELED_RATIO times larger batches, same shuffle/drop policy.
train_unlabeled_dataset = HyperX(img, train_unlabeled_gt, labeled=False, **hyperparams)
train_unlabeled_loader = data.DataLoader(
    train_unlabeled_dataset,
    batch_size=BATCH_SIZE*UNLABELED_RATIO,
    shuffle=True,
    drop_last=True,
)

# Steps per epoch. The training loop zips the labeled and unlabeled loaders,
# so it stops as soon as the shorter one is exhausted. Take the count from the
# loaders themselves (both use drop_last=True) instead of estimating it from
# the number of labeled pixels, which can overstate the real step count.
iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
total_steps = iterations * EPOCHS

3467 samples selected (over 54129)


  **kwargs)
  arrmean, rcount, out=arrmean, casting='unsafe', subok=False)
  ret, rcount, out=ret, casting='unsafe', subok=False)


In [102]:
# Sanity check: largest value in the loaded image array.
img.max()

1.0000001

In [103]:
# One float slot per class for the weights and the two pixel-count vectors.
weights, frequencies, train_freq = (np.zeros(N_CLASSES) for _ in range(3))

# Count pixels per class in the full ground truth and in the training mask;
# ignored labels keep a count of zero.
for label in range(N_CLASSES):
    if label not in IGNORED_LABELS:
        frequencies[label] = np.count_nonzero(gt == label)
        train_freq[label] = np.count_nonzero(train_gt == label)

print(train_freq)
print(frequencies)

# Turn the full-image pixel counts into relative frequencies (in place).
frequencies /= frequencies.sum()

[  0. 100. 170. 104.  57. 130. 183. 165. 492. 294. 148.  52. 104. 784.
  57. 317.  82.]
[    0.  2009.  3726.  1976.  1394.  2678.  3959.  3579. 11271.  6203.
  3278.  1068.  1927.   916.  1070.  7268.  1807.]


In [104]:
# Median of the class frequencies, skipping classes whose frequency is zero.
np.median(frequencies[frequencies != 0])

0.04329472186812984

In [105]:
# Obtain the median on non-zero frequencies
median = np.median(frequencies[np.nonzero(frequencies)])
weights = median / frequencies
weights[frequencies == 0] = 0.
weights

  This is separate from the ipykernel package so we can avoid doing imports until


array([0.        , 1.16650075, 0.62895867, 1.18598178, 1.68113343,
       0.87509335, 0.59194241, 0.65479184, 0.20792299, 0.37780106,
       0.71491763, 2.19428839, 1.21613908, 2.55840611, 2.19018692,
       0.32244084, 1.29690094])

In [106]:
# Class weights rescaled so that the largest weight equals 1.
weights / weights.max()

array([0.        , 0.45594823, 0.24584004, 0.46356275, 0.65710187,
       0.3420463 , 0.23137156, 0.25593741, 0.08127052, 0.14767048,
       0.27943868, 0.8576779 , 0.47535029, 1.        , 0.85607477,
       0.12603192, 0.50691754])

In [107]:
# Lazy iterator yielding (step index, (labeled batch, unlabeled batch)) pairs;
# it is consumed one step at a time with next() in the following cells.
iter_data = enumerate(
    zip(train_labeled_loader, train_unlabeled_loader)
)

In [161]:
# Pull the next (labeled, unlabeled) batch pair from the shared iterator.
idx, (data_l, data_u) = next(iter_data)

# Labeled batch -> (inputs, targets); unlabeled batch -> two input tensors.
# presumably input_w / input_s are weakly / strongly augmented views — TODO confirm in HyperX
(input_l, target_l), (input_w, input_s) = data_l, data_u

# Stack everything along the batch axis so a single forward pass covers all three.
inputs = torch.cat([input_l, input_w, input_s], dim=0)

In [163]:
# Largest value in the second unlabeled input tensor of this batch.
input_s.max()

tensor(0.9215)

In [164]:
# Put the model in training mode, then run one forward pass on the stacked batch.
model.train(True)
logits = model(inputs)

In [165]:
# Largest logit produced by the forward pass.
logits.max()

tensor(21.9774, grad_fn=<MaxBackward1>)

In [121]:
# Class weights computed by utils.compute_imf_weights on the training mask,
# converted from a NumPy array to a float32 tensor for the loss functions.
weights_balance = utils.compute_imf_weights(train_gt, N_CLASSES, IGNORED_LABELS)
weights = torch.from_numpy(weights_balance).to(torch.float32)

# Labeled loss uses the default reduction; the unlabeled loss stays
# per-sample (reduction='none') so it can be multiplied by a mask later.
loss_labeled = nn.CrossEntropyLoss(weight=weights)
loss_unlabeled = nn.CrossEntropyLoss(weight=weights, reduction='none')

In [122]:
# Inspect the concrete type of the labeled-loss criterion.
type(loss_labeled)

torch.nn.modules.loss.CrossEntropyLoss

In [166]:
# Undo the concatenation done before the forward pass: the first BATCH_SIZE
# rows belong to the labeled batch, the rest is split in two equal halves.
logits_l = logits[:BATCH_SIZE]
logits_w, logits_s = logits[BATCH_SIZE:].chunk(2)

# Supervised loss on the labeled part.
sup_loss = loss_labeled(logits_l, target_l)

# Pseudo-labels come from the first unlabeled half and must not carry gradient.
# Use the out-of-place detach(): logits_w is a view (slice + chunk), and
# PyTorch does not allow views to be detached in place with detach_().
psuedo_label = torch.softmax(logits_w.detach(), dim=-1)
max_probs, psuedo_target = torch.max(psuedo_label, dim=-1)
# Keep only predictions whose top probability is at least 0.95.
mask = max_probs.ge(0.95).float()

# Per-sample loss on the second half against the pseudo-labels, masked, then
# averaged over every unlabeled sample (masked-out samples contribute 0).
unsup_loss = (loss_unlabeled(logits_s, psuedo_target) * mask).mean()

In [167]:
# Confidences of the samples whose pseudo-label is class 0.
max_probs[psuedo_target.eq(0)]

tensor([])

In [130]:
# SGD with Nesterov momentum over all model parameters.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.03,
    momentum=0.9,
    nesterov=True,
)

In [168]:
# Total objective: supervised loss plus masked pseudo-label loss.
loss = sup_loss + unsup_loss

# Clear gradients left over from a previous run of these cells before
# back-propagating; otherwise re-running the forward/backward cells
# accumulates gradients and the next optimizer.step() applies their sum.
optimizer.zero_grad()
loss.backward()

In [169]:
# Display the combined loss tensor (sup_loss + unsup_loss).
loss

tensor(14.7449, grad_fn=<AddBackward0>)

In [170]:
# Apply one parameter update using the gradients produced by loss.backward().
optimizer.step()

In [171]:
# A second, freshly constructed network.
# NOTE(review): model_test is not used below — the loop inspects `model`; confirm which one was intended.
model_test = models.HamidaEtAl(N_BANDS, N_CLASSES, patch_size=PATCH_SIZE)

# Print the min / max value of every trainable parameter tensor of `model`.
trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
for name, param in trainable:
    print(name, param.data.min(), param.data.max())

conv1.weight tensor(-1.2744) tensor(1.0164)
conv1.bias tensor(-1.7973) tensor(0.0537)
pool1.weight tensor(-0.6922) tensor(0.7958)
pool1.bias tensor(-2.9161) tensor(3.3856)
conv2.weight tensor(-0.3681) tensor(0.4515)
conv2.bias tensor(-0.6057) tensor(0.1901)
pool2.weight tensor(-0.9111) tensor(0.6887)
pool2.bias tensor(-0.3551) tensor(0.3516)
conv3.weight tensor(-1.3570) tensor(0.9937)
conv3.bias tensor(-0.2725) tensor(0.1747)
conv4.weight tensor(-3.0408) tensor(0.5302)
conv4.bias tensor(-0.5896) tensor(0.0875)
fc.weight tensor(-2.2576) tensor(1.7449)
fc.bias tensor(-0.0514) tensor(0.0661)


In [None]:
# Restore saved weights into `model`.
# NOTE(review): the path ends with '/', so it looks like a directory rather than a checkpoint file — confirm the intended filename.
# NOTE(review): torch.load unpickles its input; only load checkpoints from a trusted source.
model.load_state_dict(torch.load('/home/oscar/Desktop/Exjobb/thesis/checkpoints/hamida_et_al/Salinas/'))

In [None]:
# One pass over the paired labeled / unlabeled loaders with a combined
# supervised + pseudo-label objective.
model.train()

# One slot per optimisation step. zip() stops at the shorter loader, so that
# is the number of steps; sizing the array by BATCH_SIZE would overflow as
# soon as there are more steps than samples in a batch.
n_steps = min(len(train_labeled_loader), len(train_unlabeled_loader))
max_logits = np.zeros(n_steps)

for idx, (data_x, data_u) in enumerate(zip(train_labeled_loader, train_unlabeled_loader)):
    optimizer.zero_grad()

    input_x, target_x = data_x
    input_u_w, input_u_s = data_u

    # torch.cat takes a sequence of tensors, not separate positional tensors.
    inputs = torch.cat((input_x, input_u_w, input_u_s))

    # Feed the stacked batch (`inputs`), not the `input` builtin.
    logits = model(inputs)
    # Store a plain Python float so no graph-attached tensor is kept around.
    max_logits[idx] = logits.max().item()

    # First BATCH_SIZE rows are the labeled batch; the rest splits in two halves.
    logits_l = logits[0:BATCH_SIZE]
    logits_w, logits_s = logits[BATCH_SIZE:].chunk(2)

    sup_loss = loss_labeled(logits_l, target_x)

    # Pseudo-labels must not carry gradient. logits_w is a view, and views
    # cannot be detached in place, so use the out-of-place detach().
    psuedo_label = torch.softmax(logits_w.detach(), dim=-1)
    max_probs, psuedo_target = torch.max(psuedo_label, dim=-1)
    # Keep only predictions whose top probability is at least 0.95.
    mask = max_probs.ge(0.95).float()

    unsup_loss = (loss_unlabeled(logits_s, psuedo_target) * mask).mean()

    loss = sup_loss + unsup_loss

    loss.backward()
    optimizer.step()