In [None]:
import os
import sys

from google.colab import drive

# Project folder on the mounted Google Drive; used both as the working
# directory and as an extra import location.
PROJECT_DIR = '/content/drive/My Drive/scattering_network'

drive.mount('/content/drive')
os.chdir(PROJECT_DIR)

# The sys.path entry must match the folder name exactly: the previous value
# carried a trailing space and therefore named a directory that does not exist.
sys.path.append(PROJECT_DIR)

In [None]:
import argparse
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import kymatio.datasets as scattering_datasets
from kymatio.torch import Scattering2D
# NOTE: the next import rebinds `Scattering2D` to the NumPy frontend and
# shadows the torch frontend imported on the previous line.
from kymatio.numpy import Scattering2D
from sklearn import metrics
from sklearn.metrics import precision_recall_fscore_support
from torchsummary import summary
from torchvision import datasets, transforms
# Class (species) display names used on the x-axis of the per-class precision/recall plot.

# One entry per class.
# NOTE(review): presumably this order has to match the class indices assigned
# by the dataset loader — confirm against the dataset folder names.
species = [
    'Arthonia_radiata',
    'Caloplaca_cerina',
    'Candelariella_reflexa',
    'Candelariella_xanthostigma',
    'Chrysothrix_candelaris',
    'Flavoparmelia_caperata',
    'Gyalolechia_flavorubescens',
    'Hyperphyscia_adglutinata',
    'Lecanora_argentata',
    'Lecanora_chlarotera',
    'Lecidella_elaeochroma',
    'Melanelixia_glabratula',
    'Phaeophyscia_orbicularis',
    'Physcia_biziana',
    'Physconia_grisea',
    'Ramalina_farinacea',
    'Ramalina_fastigiata',
    'Xanthomendoza_fallax',
    'Xanthomendoza_fulva',
    'flavoparmenia_soredians',
]



def calculate_and_plot_precision_recall(tst_lab, pred, species):
    """Draw per-class precision, recall and F1 as grouped bars and return their means.

    Args:
        tst_lab: ground-truth labels, forwarded to ``precision_recall_fscore_support``.
        pred: predicted labels, forwarded to the same function.
        species: class names used as bar-group labels; must have as many
            entries as the per-class score arrays.

    Returns:
        Tuple ``(mean precision, mean recall, mean F1)``, each the plain
        (unweighted) mean of the per-class scores.
    """
    per_class_precision, per_class_recall, per_class_f1, _support = precision_recall_fscore_support(tst_lab, pred)
    score_table = pd.DataFrame({
        "X": species,
        "precision": per_class_precision,
        "recall": per_class_recall,
        'f1score': per_class_f1,
    })
    score_table.plot(x="X", y=["precision", "recall", 'f1score'], kind="bar")
    plt.tight_layout()
    mean_precision = np.mean(per_class_precision)
    mean_recall = np.mean(per_class_recall)
    mean_f1 = np.mean(per_class_f1)
    return mean_precision, mean_recall, mean_f1


def plot_confusion_matrix(cm, classes, acc,normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: square confusion matrix; a torch tensor (moved to CPU first) or
            anything ``np.asarray`` accepts. Rows are treated as true labels.
        classes: tick labels, one per row/column.
        acc: value appended to the title after ``'. Mean accuracy: '``.
        normalize: when true, every row is divided by its sum.
        title: title prefix.
        cmap: matplotlib colour map for the heat map.
    """
    cm = cm.cpu().numpy() if torch.is_tensor(cm) else np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any sample stay at 0 instead of turning into NaN,
        # which would also poison the colour threshold computed below.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    heading = title + '. Mean accuracy: ' + str(acc)
    plt.title(heading)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Values of 0.10 or less are left unannotated to keep the plot readable.
        if cm[i, j] > 0.10:
            plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

@torch.no_grad()
def get_all_preds(model, loader):
    """Return the predicted class index for every sample yielded by ``loader``.

    Relies on the module-level ``device`` and ``scattering`` objects: each
    batch is moved to ``device`` and passed through ``scattering`` before the
    model. The labels yielded by the loader are ignored.

    The model is switched to evaluation mode for the duration of the call and
    put back into training mode afterwards if that is how it was found.

    Returns:
        1-D tensor on ``device`` with one prediction per sample, in loader order.
    """
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    # The empty float seed tensor is kept so the concatenated result has the
    # same dtype as the previous incremental implementation.
    chunks = [torch.tensor([]).to(device)]
    try:
        for images, _labels in loader:
            output = model(scattering(images.to(device)))
            _, preds = torch.max(output, 1)
            chunks.append(preds)
    finally:
        if was_training:
            model.train()
    # Single concatenation instead of re-allocating the result on every batch.
    return torch.cat(chunks, 0)


def calculate_accuracy(pred, true_lab):
    """Compare predictions with labels position by position.

    Args:
        pred: indexable sequence of predictions (list, NumPy array, tensor).
        true_lab: indexable sequence of labels, at least as long as ``pred``.

    Returns:
        Tuple ``(fraction correct, number correct)``. Empty input yields
        ``(0.0, 0)`` instead of raising ``ZeroDivisionError``.
    """
    total = len(pred)
    if total == 0:
        return 0.0, 0
    cont = sum(1 for i in range(total) if pred[i] == true_lab[i])
    return cont/total, cont


In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install kymatio

In [None]:

# Sanity check of the scattering output size on a random 3-channel image.
# The input is a NumPy array, so this cell relies on `Scattering2D` being
# bound to the NumPy frontend (kymatio.numpy).

# Scale of the transform. It was previously declared as 3 but never used:
# the transform itself was (and still is) built with J=2.
J = 2
M, N = 100, 100

# Generate a sample signal.
x = np.random.randn(3, M, N)

# Define a Scattering2D object.
S = Scattering2D(J=J, L=4, shape=(M, N))

# Apply the transform by calling the object.
Sx = S(x)
print(Sx.shape)
# Product of the first two output dimensions.
# NOTE(review): presumably this is the value hard-coded as K when the model is built — confirm.
print(Sx.shape[0]*Sx.shape[1])

In [None]:


class Scattering2dCNN(nn.Module):
    '''
    Classifier applied on top of 2-D scattering coefficients.

    The input is reshaped to ``(batch, in_channels, dim_im, dim_im)``,
    batch-normalised and then passed to one of three heads:

    * ``'cnn'``    - small VGG-style stack of 3x3 convolutions followed by a
                     linear layer,
    * ``'mlp'``    - three fully connected layers,
    * ``'linear'`` - a single fully connected layer.

    Args:
        in_channels: number of channels after flattening the scattering output.
        dim_im: spatial size (height == width) of the scattering output.
        classifier_type: ``'cnn'``, ``'mlp'`` or ``'linear'``.
        num_classes: number of output logits (defaults to the previous
            hard-coded value of 20).

    Raises:
        ValueError: if ``classifier_type`` is not one of the supported names.
    '''
    def __init__(self, in_channels,dim_im, classifier_type='cnn', num_classes=20):
        super(Scattering2dCNN, self).__init__()
        self.in_channels = in_channels
        self.dim_im = dim_im
        self.classifier_type = classifier_type
        self.num_classes = num_classes
        self.build()

    def build(self):
        '''Create the input batch-norm and the head selected by ``classifier_type``.'''
        cfg = [256, 'M', 512,512]
        self.K = self.in_channels
        self.bn = nn.BatchNorm2d(self.K)
        # Size of the flattened input seen by the heads that have no conv stack.
        flat_features = self.K * self.dim_im * self.dim_im
        if self.classifier_type == 'cnn':
            layers = []
            # Track the running channel count locally so that build() no
            # longer overwrites self.in_channels.
            channels = self.K
            for v in cfg:
                if v == 'M':
                    layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
                else:
                    conv2d = nn.Conv2d(channels, v, kernel_size=3, padding=1)
                    layers += [ conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                    channels = v

            # Pool to a fixed 2x2 map so the linear layer size does not depend on dim_im.
            layers += [nn.AdaptiveAvgPool2d(2)]
            self.features = nn.Sequential(*layers)
            self.classifier =  nn.Linear(channels*4, self.num_classes)
            # Kept for compatibility; forward() returns raw logits and does not use it.
            self.softmax = nn.LogSoftmax(dim=1)

        elif self.classifier_type == 'mlp':
            # Layer widths now chain correctly (the first layer used to emit
            # 512 features into a layer expecting 1024) and the input size
            # follows dim_im instead of a fixed 8x8 map.
            self.classifier = nn.Sequential(
                        nn.Linear(flat_features, 1024), nn.ReLU(),
                        nn.Linear(1024, 1024), nn.ReLU(),
                        nn.Linear(1024, self.num_classes))
            self.features = None

        elif self.classifier_type == 'linear':
            # Input size follows dim_im instead of a fixed 25x25 map.
            self.classifier = nn.Linear(flat_features, self.num_classes)
            self.features = None

        else:
            raise ValueError(
                "classifier_type must be 'cnn', 'mlp' or 'linear', got {!r}".format(self.classifier_type))


    def forward(self, x):
        '''Return unnormalised class scores (logits) for a batch of scattering outputs.'''
        # Collapse any extra leading dimensions into (batch, K, dim_im, dim_im).
        x = self.bn(x.view(-1, self.K, self.dim_im, self.dim_im))
        if self.features is not None:
            x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        # No softmax here: the training code feeds these scores to
        # F.cross_entropy, which expects raw logits.
        return x




def train(model, device, train_loader, optimizer, epoch, scattering):
    """Run one training epoch and return the mean batch loss and the accuracy.

    Args:
        model: network applied to the scattering output.
        device: device the batches are moved to.
        train_loader: iterable of ``(data, target)`` batches exposing ``.dataset``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress message.
        scattering: callable applied to each batch before the model.

    Returns:
        Tuple ``(mean of the per-batch losses, accuracy in percent)``, both
        NumPy scalars. Accuracy is ``correct / samples seen``, so a smaller
        final batch is weighted correctly instead of being divided by a fixed
        batch size of 32.
    """
    model.train()
    batch_losses = []
    correct = 0
    seen = 0
    for batch_idx, (data,target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(scattering(data))
        loss = F.cross_entropy(output, target)
        batch_losses.append(loss.item())
        loss.backward()
        optimizer.step()
        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == target).item()
        seen += target.size(0)

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item()))

    mean_loss = np.mean(np.array(batch_losses))
    # Stay a NumPy scalar: the training loop calls .item() on this value.
    accuracy = np.float64(correct) / seen * 100 if seen else np.float64('nan')
    return mean_loss, accuracy




def test(model, device, test_loader, scattering):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(scattering(data))
            test_loss += F.cross_entropy(output, target, reduction = 'sum').item()
            pred = output.max(1, keepdim = True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
          test_loss, correct, len(test_loader.dataset),
          100. * correct / len(test_loader.dataset)))
    
    return test_loss, correct, 100. * correct / len(test_loader.dataset)

In [None]:



# Training pipeline: resize, random augmentation, then tensor conversion and
# per-channel normalisation with mean 0.5 / std 0.5.
transform_train = transforms.Compose([
    transforms.Resize((100,100)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, shear = 10, scale = (0.6,1.1)),
    transforms.ColorJitter(brightness = 0.2,contrast = 0.2,saturation = 0.2),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])

# Validation pipeline: same resize and normalisation, no augmentation.
transform = transforms.Compose([
    transforms.Resize((100,100)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])

# Image folders, relative to the current working directory.
data_dir = '../dataset/data/train'
data_dir_val = '../dataset/data/valid'
training_dataset = datasets.ImageFolder(data_dir, transform=transform_train)
validation_dataset = datasets.ImageFolder(data_dir_val, transform = transform)

# Only the training loader shuffles; the validation order stays fixed.
train_loader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size =32,
    shuffle = True,
)
test_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size = 32,
    shuffle = False,
)

In [None]:
# The training code feeds torch tensors to the scattering transform, so the
# torch frontend is imported explicitly here; the bare name `Scattering2D`
# is rebound to the NumPy frontend by the import cell.
from kymatio.torch import Scattering2D as TorchScattering2D

use_cuda = torch.cuda.is_available()

device = torch.device("cuda" if use_cuda else "cpu")
# Follow `device` instead of calling .cuda() unconditionally, so the CPU
# fallback selected above actually works.
scattering = TorchScattering2D(J=2,L =4, shape=(100, 100)).to(device)
# Number of input channels of the classifier.
# NOTE(review): presumably 3 image channels x 25 scattering coefficients — confirm.
K = 75
model = Scattering2dCNN(K,25,'cnn').to(device)

In [None]:
# torchsummary assumes a CUDA model unless told otherwise; pass the device
# type the model actually lives on so the summary also works on CPU.
summary(model, (K, 25, 25), device=device.type)

In [None]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install hiddenlayer

In [None]:
# Training driver: up to 100 epochs of train + evaluate, recording the loss
# and accuracy histories that are plotted afterwards.
early_stop = False  # never read in this cell
stopping_count = 0  # number of epochs whose test accuracy was below the previous epoch's
lr = 0.01 
max_acc = 0  # best test accuracy seen so far (the matching weights are not saved here)
tr_loss_tl = []
tr_acc_tl = []
tst_loss_tl = []
tst_acc_tl = []
prev_acc = 0
for epoch in range(0,100):
    #print('------------------------- ',epoch,' ----------------------------')
    # Step schedule: every 20 epochs a fresh SGD optimizer is created with the
    # current learning rate, after which the rate is multiplied by 0.2 for the
    # next rebuild. A new optimizer object starts without the previous one's state.
    if epoch%20 ==0:
        optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum =0.9,weight_decay = 0.0005)
        
        lr *=0.2
    
    tr_ls,tr_ac = train(model, device, train_loader, optimizer, epoch +1, scattering)
    ts_ls , _ , ts_acc = test(model, device, test_loader, scattering)
    # Count every epoch-to-epoch drop in test accuracy; the counter is never reset.
    if prev_acc > ts_acc:
        stopping_count = stopping_count + 1 
    prev_acc = ts_acc  
    if ts_acc > max_acc:
        max_acc = ts_acc

    tr_loss_tl.append(tr_ls)
    # train() returns the accuracy as a NumPy scalar; store it as a plain float.
    tr_acc_tl.append(tr_ac.item())
    tst_loss_tl.append(ts_ls)
    tst_acc_tl.append(ts_acc)

    # NOTE(review): stopping_count grows by at most 1 per epoch and the loop
    # runs at most 100 epochs, so this condition can never be true — early
    # stopping is effectively disabled. Confirm the intended patience.
    if stopping_count > 300:
        print("EARLY STOPPING PERFORMED AT EPOCH ", epoch)
        break 

In [None]:
# Predict the whole validation set. The batch size is far larger than the
# dataset is expected to be, so everything goes through in a single batch.
prediction_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=100000)
with torch.no_grad():
    test_preds = get_all_preds(model, prediction_loader)

# Compare the predictions with the dataset's own labels.
true_labels = validation_dataset.targets
c = test_preds.cpu().numpy()

accuracy, corrects = calculate_accuracy(c, true_labels)
# Final accuracy, number of correct predictions, best test accuracy during training.
for reported in (accuracy, corrects, max_acc):
    print(reported)

In [None]:
def plot_and_save_image(running_corrects,val_running_corrects, running_loss,val_running_loss, title, direc):
    """Plot accuracy and loss curves side by side and save the figure.

    Args:
        running_corrects: training accuracy per epoch, in percent.
        val_running_corrects: test accuracy per epoch, in percent.
        running_loss: training loss per epoch.
        val_running_loss: test loss per epoch.
        title: figure title.
        direc: output file path; its parent directory is created if missing.

    The accuracy values are rescaled from percent to fractions on copies, so
    the caller's lists are left untouched and the function can be called
    repeatedly with the same data.
    """
    import os

    train_acc = [value / 100 for value in running_corrects]
    val_acc = [value / 100 for value in val_running_corrects]

    fig, (ax1, ax2) = plt.subplots(1, 2,figsize=(15,5))
    fig.suptitle(title)

    ax1.plot(train_acc)
    ax1.plot(val_acc)
    ax1.set_title('model accuracy')
    ax1.legend(['train', 'test'], loc='upper left')

    ax2.plot(running_loss)
    ax2.plot(val_running_loss)
    ax2.set_title('model loss')
    ax2.legend(['train', 'test'], loc='upper left')

    for ax, y_label in zip((ax1, ax2), ('accuracy', 'loss')):
        ax.set(xlabel='epochs', ylabel=y_label)
        ax.grid()

    # Saving fails if the target directory does not exist yet.
    out_dir = os.path.dirname(direc)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(direc)

In [None]:
# Save the accuracy/loss curves recorded by the training loop.
plot_and_save_image(
    running_corrects=tr_acc_tl,
    val_running_corrects=tst_acc_tl,
    running_loss=tr_loss_tl,
    val_running_loss=tst_loss_tl,
    title='wave_shortcnn_lichens',
    direc='results/4_4wave_shortcnn_lichens.jpg',
)

In [None]:

# Take the class names from the dataset itself so that position i is the name
# of label i. os.listdir() returns entries in arbitrary order (and may include
# stray files), which can mislabel the plots below.
classes = training_dataset.classes
def im_convert(tensor):
  """Turn a normalised channel-first image tensor into a displayable array.

  The tensor is copied to the CPU, reordered from (C, H, W) to (H, W, C),
  un-normalised with std 0.5 / mean 0.5 and clipped to the range [0, 1].
  """
  channels_first = tensor.cpu().clone().detach().numpy()
  channels_last = channels_first.transpose(1, 2, 0)
  std = np.array((0.5, 0.5, 0.5))
  mean = np.array((0.5, 0.5, 0.5))
  restored = channels_last * std + mean
  return restored.clip(0, 1)



# Preview up to ten augmented training images with their class names.
dataiter = iter(train_loader)
# Use the built-in next(): DataLoader iterators no longer provide a .next() method.
images, labels = next(dataiter)
fig = plt.figure(figsize=(25, 4))

# Guard against a batch that holds fewer than ten images.
for idx in range(min(10, len(images))):
  ax = fig.add_subplot(2, 5, idx+1, xticks=[], yticks=[])
  ax.imshow(im_convert(images[idx]))
  ax.set_title(classes[labels[idx].item()])
plt.savefig('results/dataset_with_augmentation.jpg')

In [None]:
# Show up to eight random validation images titled "predicted (true)",
# green when the prediction is correct and red otherwise.
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size = 32, shuffle = True)
dataiter = iter(validation_loader)
# Use the built-in next(): DataLoader iterators no longer provide a .next() method.
images, labels = next(dataiter)
images = images.to(device)
labels = labels.to(device)

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    output = model(scattering(images))
_, preds = torch.max(output, 1)

# Class names come from the dataset so that index i is the name of label i.
# NOTE(review): assumes training and validation folders define the same classes — confirm.
class_names = validation_dataset.classes

fig = plt.figure(figsize=(50, 10))

for idx in range(min(8, len(images))):
  ax = fig.add_subplot(2, 4, idx+1, xticks=[], yticks=[])
  ax.imshow(im_convert(images[idx]))
  ax.set_title("{} ({})".format(str(class_names[preds[idx].item()]), str(class_names[labels[idx].item()])), color=("green" if preds[idx]==labels[idx] else "red"))
plt.savefig('results/evaluate_cnnshort_6.jpg')

# Create confusion matrix

In [None]:
def calculate_and_plot_precision_recall(tst_lab, pred, species):
    """Draw per-class precision, recall and F1 as grouped bars and return their means.

    Args:
        tst_lab: ground-truth labels, forwarded to ``precision_recall_fscore_support``.
        pred: predicted labels, forwarded to the same function.
        species: class names used as bar-group labels; must have as many
            entries as the per-class score arrays.

    Returns:
        Tuple ``(mean precision, mean recall, mean F1)``, each the plain
        (unweighted) mean of the per-class scores.
    """
    per_class_precision, per_class_recall, per_class_f1, _support = precision_recall_fscore_support(tst_lab, pred)
    score_table = pd.DataFrame({
        "X": species,
        "precision": per_class_precision,
        "recall": per_class_recall,
        'f1score': per_class_f1,
    })
    score_table.plot(x="X", y=["precision", "recall", 'f1score'], kind="bar")
    plt.tight_layout()
    mean_precision = np.mean(per_class_precision)
    mean_recall = np.mean(per_class_recall)
    mean_f1 = np.mean(per_class_f1)
    return mean_precision, mean_recall, mean_f1


def plot_confusion_matrix(cm, classes, acc,normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Render a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: square confusion matrix; a torch tensor (moved to CPU first) or
            anything ``np.asarray`` accepts. Rows are treated as true labels.
        classes: tick labels, one per row/column.
        acc: value appended to the title after ``'. Mean accuracy: '``.
        normalize: when true, every row is divided by its sum.
        title: title prefix.
        cmap: matplotlib colour map for the heat map.
    """
    cm = cm.cpu().numpy() if torch.is_tensor(cm) else np.asarray(cm)
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # Rows without any sample stay at 0 instead of turning into NaN,
        # which would also poison the colour threshold computed below.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    heading = title + '. Mean accuracy: ' + str(acc)
    plt.title(heading)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=90)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Values of 0.10 or less are left unannotated to keep the plot readable.
        if cm[i, j] > 0.10:
            plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')

@torch.no_grad()
def get_all_preds(model, loader):
    """Return the predicted class index for every sample yielded by ``loader``.

    Relies on the module-level ``device`` and ``scattering`` objects: each
    batch is moved to ``device`` and passed through ``scattering`` before the
    model. The labels yielded by the loader are ignored.

    The model is switched to evaluation mode for the duration of the call and
    put back into training mode afterwards if that is how it was found.

    Returns:
        1-D tensor on ``device`` with one prediction per sample, in loader order.
    """
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    # The empty float seed tensor is kept so the concatenated result has the
    # same dtype as the previous incremental implementation.
    chunks = [torch.tensor([]).to(device)]
    try:
        for images, _labels in loader:
            output = model(scattering(images.to(device)))
            _, preds = torch.max(output, 1)
            chunks.append(preds)
    finally:
        if was_training:
            model.train()
    # Single concatenation instead of re-allocating the result on every batch.
    return torch.cat(chunks, 0)

# Predict the whole validation set; the oversized batch size means the data
# is expected to go through in a single batch.
prediction_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=100000)
with torch.no_grad():
    test_preds = get_all_preds(model, prediction_loader)

def calculate_accuracy(pred, true_lab):
    """Compare predictions with labels position by position.

    Args:
        pred: indexable sequence of predictions (list, NumPy array, tensor).
        true_lab: indexable sequence of labels, at least as long as ``pred``.

    Returns:
        Tuple ``(fraction correct, number correct)``. Empty input yields
        ``(0.0, 0)`` instead of raising ``ZeroDivisionError``.
    """
    total = len(pred)
    if total == 0:
        return 0.0, 0
    cont = sum(1 for i in range(total) if pred[i] == true_lab[i])
    return cont/total, cont


# Overall accuracy of the collected predictions against the dataset labels.
true_labels = validation_dataset.targets
c = test_preds.cpu().numpy()

accuracy, corrects = calculate_accuracy(c, true_labels)
print(accuracy)
print(corrects)

# Pair every true label with its prediction, then count the pairs into a
# 20x20 matrix indexed as [true label, predicted label].
stacked = torch.stack((torch.Tensor(validation_dataset.targets),test_preds.cpu()),dim=1)
cmt = torch.zeros(20,20, dtype=torch.int64)
for tl, pl in stacked.tolist():
    row, col = int(tl), int(pl)
    cmt[row, col] += 1


plt.figure(figsize=(11,10))
plot_confusion_matrix(cmt,validation_dataset.classes,accuracy,normalize=True)
plt.savefig('conf_matrix_scat_cnn.jpg')

import pandas as pd

# Mean precision, recall and F1 over the classes.
# Note that `c` is rebound here and no longer holds the prediction array.
a,b,c = calculate_and_plot_precision_recall(list(validation_dataset.targets), list(test_preds.cpu().numpy()), species)