In [1]:
from __future__ import print_function

import numpy as np
from torch.utils import data
import argparse
import torch
import torch.utils.data as data_utils
import torch.optim as optim
from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import roc_auc_score

seed =77
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

def seed_worker(worker_id):
    worker_seed = 11
    numpy.random.seed(worker_seed)
    random.seed(worker_seed)



class Dataset(data.Dataset):
    def __init__(self, X1,Y):
        self.X1 = X1
        self.Y = Y
    def __len__(self):        
        return len(self.X1)
    
    def __getitem__(self, index):
        x = self.X1[index]
        y = self.Y[index]
        return x,y



## Negetive pictures

In [2]:
xTrain=np.load("../extracted_features/FS-change3/input_features_free_train.npy")

yTrain=np.load("../extracted_features/FS-change3/label_free_train.npy")

xTest=np.load("../extracted_features/FS-change3/input_features_free_test.npy")

yTest=np.load("../extracted_features/FS-change3/label_free_test.npy")

xVal=np.load("../extracted_features/FS-change3/input_features_free_val.npy")

yVal=np.load("../extracted_features/FS-change3/label_free_val.npy")


In [3]:
xTrain.shape

(1121, 19, 8)

In [4]:
xVal.shape

(199, 19, 8)

In [5]:
xTest.shape

(190, 19, 8)

In [6]:
xTrain = xTrain.reshape(len(xTrain),1,19,8)
xTest = xTest.reshape(len(xTest),1,19,8)
xVal = xVal.reshape(len(xVal),1,19,8)

In [7]:
print(yTest)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [8]:
print(xTrain.shape,yTrain.shape)
print(xTest.shape,yTest.shape)
print(xVal.shape,yVal.shape)

(1121, 1, 19, 8) (1121,)
(190, 1, 19, 8) (190,)
(199, 1, 19, 8) (199,)


In [9]:
traindataset = Dataset(xTrain,yTrain)
testdataset = Dataset(xTest,yTest)
valdataset = Dataset(xVal,yVal)

In [10]:
class Stammering(data_utils.Dataset):
    def __init__(self, target_number=1, mean_bag_length=5, var_bag_length=2, num_bag=150, seed=2021, train="train"):
        self.target_number = target_number
        self.mean_bag_length = mean_bag_length
        self.var_bag_length = var_bag_length
        self.num_bag = num_bag
        self.train = train
        self.r = np.random.RandomState(seed)

        if self.train=="train":
            self.train_bags_list, self.train_labels_list = self._create_bags()
        elif self.train=="val":
            self.val_bags_list, self.val_labels_list = self._create_bags()
        else:
            self.test_bags_list, self.test_labels_list = self._create_bags()

    def _create_bags(self):
        if self.train=="train":
            print("train")
            loader = data_utils.DataLoader(traindataset,
                                           batch_size=1,
                                           shuffle=True)
        elif self.train=="val":
            print("val")
            loader = data_utils.DataLoader(valdataset,
                                           batch_size=1,
                                           shuffle=True)
        else:
            loader = data_utils.DataLoader(testdataset,
                                           batch_size=1,
                                           shuffle=True)
            
        bags_list = []
        labels_list = []
        for (batch_data, batch_labels) in loader:

            bags_list.append(batch_data.reshape(19,1,8))
            temp = torch.as_tensor(np.array([batch_labels for x in range(19)]))
            labels_list.append(temp)
            
               

        return bags_list, labels_list

    def __len__(self):
        if self.train=="train":
            return len(self.train_labels_list)
        elif self.train=="val":
            return len(self.val_labels_list)
        else:
            return len(self.test_labels_list)

    def __getitem__(self, index):
        if self.train=="train":
            bag = self.train_bags_list[index]
            label = [max(self.train_labels_list[index]), self.train_labels_list[index]]
        elif self.train=="val":
            bag = self.val_bags_list[index]
            label = [max(self.val_labels_list[index]), self.val_labels_list[index]]
        else:
            bag = self.test_bags_list[index]
            label = [max(self.test_labels_list[index]), self.test_labels_list[index]]

        return bag, label

In [11]:
train_loader = data_utils.DataLoader(Stammering(train="train"),num_workers=0,worker_init_fn=seed_worker,batch_size=1,shuffle=True)
test_loader = data_utils.DataLoader(Stammering(train="val"),num_workers=0,worker_init_fn=seed_worker,batch_size=1,shuffle=True)
val_loader = data_utils.DataLoader(Stammering(train="test"),num_workers=0,worker_init_fn=seed_worker,batch_size=1,shuffle=True)

train
val


In [12]:
def init_weights(m):
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
    if type(m) == nn.Conv1d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)
        
class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
        self.L = 512
        self.D = 128
        self.K = 1

        self.feature_extractor_part1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=2),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=2),
            nn.ReLU(),
            nn.Linear(6, 128),
            nn.ReLU(),
            nn.Linear(128,16),
            nn.ReLU()
        )

        self.feature_extractor_part2 = nn.Sequential(
            nn.BatchNorm1d(512, affine=False),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024,self.L),
            nn.ReLU()
        )

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )

        self.classifier = nn.Sequential(

            nn.Linear(self.L*self.K, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x=x.squeeze(0)

        H = self.feature_extractor_part1(x)
        
        H = H.view(H.size(0), -1)

        H = self.feature_extractor_part2(H)  

        A = self.attention(H)  

        A = torch.transpose(A, 1, 0) 

        A = F.softmax(A, dim=1)  

        M = torch.mm(A, H)  

        Y_prob = self.classifier(M)
        Y_hat = torch.ge(Y_prob, 0.5).float()

        return Y_prob




    # AUXILIARY METHODS
    def calculate_classification_error(self, X, Y):
        Y = Y.float()
        Y_prob= self.forward(X)
        Y_hat = torch.ge(Y_prob, 0.5).float()
        error = 1. - Y_hat.eq(Y).cpu().float().mean().item()

        return error, Y_hat

    def calculate_objective(self, X, Y):
        Y = Y.float()
        Y_prob = self.forward(X)
        Y_prob = torch.clamp(Y_prob, min=1e-5, max=1. - 1e-5)
        neg_log_likelihood = -1. * (Y * torch.log(Y_prob) + (1. - Y) * torch.log(1. - Y_prob))  # negative log bernoulli

        return neg_log_likelihood

In [13]:
# from torchsummary import summary
# model=Attention()
# summary(model.cuda(),(1,8))

In [14]:
LR = 1e-4
model=Attention()
if torch.cuda.is_available():model.cuda()
optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.9, 0.999))

best_model = "./saved_models/best_model_Attention-mil-fs-grp3_change"
train_acc = []
modelname=[]
val_acc = []
best_acc = 0
for epoch in range(0, 50): #12345-24,
    model.train()
    train_loss = 0.
    train_error = 0.
    y =[]
    ypred = []
    for batch_idx, (data, label) in enumerate(train_loader):
        bag_label = label[0]
        data, bag_label = data.cuda(), bag_label.cuda()
        data, bag_label = Variable(data), Variable(bag_label)
        optimizer.zero_grad()
        loss = model.calculate_objective(data.float(), bag_label.float())
        train_loss += loss.data[0]
        error,y_pred = model.calculate_classification_error(data.float(), bag_label.float())
        train_error += error
        y_pred=y_pred.squeeze(1)
        ypred.extend(y_pred.tolist())
        y.extend(bag_label.tolist())
        loss.mean().backward()
        optimizer.step()

    
    trainacc = accuracy_score(y,ypred)
    train_loss /= len(train_loader)
    train_error /= len(train_loader)
    print("EPOCH ",epoch)
    print('Train : Loss: {:.4f}, Train error: {:.4f}, Train f1 : {}'.format(train_loss.cpu().numpy()[0], 
                                                                                train_error,trainacc))
    y =[]
    ypred = []
    val_error=0.
    val_loss= 0.
    model.eval()
    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(val_loader):
            bag_label = label[0]
            data, bag_label = data.cuda(), bag_label.cuda()
            data, bag_label = Variable(data), Variable(bag_label)
            loss = model.calculate_objective(data.float(), bag_label.float())
            val_loss += loss.data[0]
            error, y_pred = model.calculate_classification_error(data.float(), bag_label.float())
            val_error += error
            y_pred=y_pred.squeeze(1)
            ypred.extend(y_pred.tolist())
            y.extend(bag_label.tolist())
        valacc = accuracy_score(y,ypred)
        val_loss /= len(val_loader)
        val_error /= len(val_loader)
        print('Val : Loss: {:.4f}, val error: {:.4f}, Val f1 :{}'.format(val_loss.cpu().numpy()[0], val_error,valacc))
        if valacc>=best_acc:
            print("---------State saved---------")
            best_acc = valacc
            best_state=model.state_dict()
            torch.save(best_state, best_model+'_epoch_'+str(epoch)+".pth")
            modelname.append(best_model+'_epoch_'+str(epoch)+".pth")
        print('Best validation accuracy ',best_acc)

        

EPOCH  0
Train : Loss: 0.6780, Train error: 0.4389, Train f1 : 0.5611061552185549
Val : Loss: 0.7318, val error: 0.6158, Val f1 :0.38421052631578945
---------State saved---------
Best validation accuracy  0.38421052631578945
EPOCH  1
Train : Loss: 0.6341, Train error: 0.3782, Train f1 : 0.6217662801070473
Val : Loss: 0.7542, val error: 0.5684, Val f1 :0.43157894736842106
---------State saved---------
Best validation accuracy  0.43157894736842106
EPOCH  2
Train : Loss: 0.5982, Train error: 0.3238, Train f1 : 0.6761819803746655
Val : Loss: 0.8564, val error: 0.5053, Val f1 :0.49473684210526314
---------State saved---------
Best validation accuracy  0.49473684210526314
EPOCH  3
Train : Loss: 0.5589, Train error: 0.2890, Train f1 : 0.7109723461195361
Val : Loss: 0.8440, val error: 0.5053, Val f1 :0.49473684210526314
---------State saved---------
Best validation accuracy  0.49473684210526314
EPOCH  4
Train : Loss: 0.5433, Train error: 0.2881, Train f1 : 0.711864406779661
Val : Loss: 0.9460,

EPOCH  41
Train : Loss: 0.1872, Train error: 0.0785, Train f1 : 0.9214986619090099
Val : Loss: 1.5742, val error: 0.3474, Val f1 :0.6526315789473685
Best validation accuracy  0.7052631578947368
EPOCH  42
Train : Loss: 0.1346, Train error: 0.0491, Train f1 : 0.9509366636931311
Val : Loss: 2.2045, val error: 0.3737, Val f1 :0.6263157894736842
Best validation accuracy  0.7052631578947368
EPOCH  43
Train : Loss: 0.1725, Train error: 0.0758, Train f1 : 0.9241748438893844
Val : Loss: 2.1079, val error: 0.4474, Val f1 :0.5526315789473685
Best validation accuracy  0.7052631578947368
EPOCH  44
Train : Loss: 0.1726, Train error: 0.0660, Train f1 : 0.9339875111507583
Val : Loss: 2.1460, val error: 0.4000, Val f1 :0.6
Best validation accuracy  0.7052631578947368
EPOCH  45
Train : Loss: 0.1255, Train error: 0.0491, Train f1 : 0.9509366636931311
Val : Loss: 2.4455, val error: 0.3421, Val f1 :0.6578947368421053
Best validation accuracy  0.7052631578947368
EPOCH  46
Train : Loss: 0.1400, Train error: 

In [15]:
device = torch.device("cuda")
model=Attention()
# best_state=torch.load("../best-models/Attention_MIL_best_free_speech_F1_73.pth")
best_state=torch.load(modelname[-1])
model.load_state_dict(best_state)
model.to(device)
model.eval()
test_loss = 0.
correct = 0
total = 0
y =[]
ypred = []
ypred1=[]
test_error = 0.
for batch_idx, (data, label) in enumerate(test_loader):
    bag_label = label[0]
    instance_labels = label[1]
    data, bag_label = data.cuda(), bag_label.cuda()
    data, bag_label = Variable(data), Variable(bag_label)
    y_pred = model(data.float())
    error, predicted_label = model.calculate_classification_error(data.float(), bag_label.float())
    test_error += error
    predicted_label=predicted_label.squeeze(1)
    y_pred=y_pred.squeeze(1)
    ypred.extend(predicted_label.tolist())
    ypred1.extend(y_pred.tolist())
    y.extend(bag_label.tolist())
    
acc=accuracy_score(y,ypred)
# print(y,ypred)
tn, fp, fn, tp = confusion_matrix(y,ypred).ravel()
f1score=f1_score(y,ypred)
precision=precision_score(y,ypred)
recall=recall_score(y,ypred)
roc=roc_auc_score(y,ypred)
specificity=tn/(tn+fp)
print(acc,f1score,precision,recall,roc,specificity)    

0.7638190954773869 0.7374301675977653 0.7415730337078652 0.7333333333333333 0.7611620795107034 0.7889908256880734


In [16]:
# for i in range(len(modelname)):
#     device = torch.device("cuda")
#     model=Attention()
#     # best_state=torch.load("../best-models/Attention_MIL_best_free_speech_F1_73.pth")
#     best_state=torch.load(modelname[i])
#     model.load_state_dict(best_state)
#     model.to(device)
#     model.eval()
#     test_loss = 0.
#     correct = 0
#     total = 0
#     y =[]
#     ypred = []
#     ypred1=[]
#     test_error = 0.
#     for batch_idx, (data, label) in enumerate(test_loader):
#         bag_label = label[0]
#         instance_labels = label[1]
#         data, bag_label = data.cuda(), bag_label.cuda()
#         data, bag_label = Variable(data), Variable(bag_label)
#         y_pred = model(data.float())
#         error, predicted_label = model.calculate_classification_error(data.float(), bag_label.float())
#         test_error += error
#         predicted_label=predicted_label.squeeze(1)
#         y_pred=y_pred.squeeze(1)
#         ypred.extend(predicted_label.tolist())
#         ypred1.extend(y_pred.tolist())
#         y.extend(bag_label.tolist())

#     acc=accuracy_score(y,ypred)
#     # print(y,ypred)
#     tn, fp, fn, tp = confusion_matrix(y,ypred).ravel()
#     f1score=f1_score(y,ypred)
#     precision=precision_score(y,ypred)
#     recall=recall_score(y,ypred)
#     roc=roc_auc_score(y,ypred)
#     specificity=tn/(tn+fp)
#     print(acc,f1score,precision,recall,roc,specificity)    

0.49748743718592964 0.626865671641791 0.47191011235955055 0.9333333333333333 0.5354740061162079 0.13761467889908258
0.5879396984924623 0.6434782608695652 0.5285714285714286 0.8222222222222222 0.6083588175331295 0.3944954128440367
0.6532663316582915 0.6567164179104478 0.5945945945945946 0.7333333333333333 0.6602446483180429 0.5871559633027523
0.5829145728643216 0.29059829059829057 0.6296296296296297 0.18888888888888888 0.548572884811417 0.908256880733945
0.4824120603015075 0.5960784313725491 0.46060606060606063 0.8444444444444444 0.5139653414882772 0.1834862385321101
0.457286432160804 0.2894736842105263 0.3548387096774194 0.24444444444444444 0.43873598369011213 0.6330275229357798
0.5226130653266332 0.5128205128205129 0.47619047619047616 0.5555555555555556 0.5254841997961264 0.4954128440366973
0.5326633165829145 0.42944785276073616 0.4794520547945205 0.3888888888888889 0.5201325178389399 0.6513761467889908
0.5376884422110553 0.5964912280701755 0.4927536231884058 0.7555555555555555 0.5566