In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision

from tqdm.notebook import tqdm

In [2]:
try:
    from torchsummary import summary
except ModuleNotFoundError: 
    !pip install torchsummary
    from torchsummary import summary

In [3]:
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"TORCH DEVICE: {device}")

TORCH DEVICE: cuda


In [4]:
train_dir = '../input/gravity-spy-gravitational-waves/train/train/'
test_dir = '../input//gravity-spy-gravitational-waves/test/test/'
validation_dir = '../input//gravity-spy-gravitational-waves/validation/validation/'

In [5]:
class_names = os.listdir(train_dir)
n_classes=len(class_names)
n_classes

22

In [None]:
plt.figure(figsize=(10,50))
for fold, i in zip(class_names,range(0,len(class_names))):

    plt.subplot(11,2, i+1)
    img_read = plt.imread(train_dir+fold+'/'+os.listdir(train_dir+fold)[0])
    plt.imshow(img_read)
    plt.title(class_names[i])
    
plt.show()

In [6]:
img_shape=[64,64]
batch_size=128

In [7]:
transform = torchvision.transforms.Compose([torchvision.transforms.Resize(img_shape),
                                torchvision.transforms.ToTensor(),
                                torchvision.transforms.Grayscale(),
                                torchvision.transforms.Normalize((0.5), (0.5))])



In [8]:
train_set = torchvision.datasets.ImageFolder(train_dir, transform)
val_set = torchvision.datasets.ImageFolder(validation_dir,transform)
#test_set = torchvision.datasets.ImageFolder(test_dir,transform)

In [9]:
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)
#test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)

## Network utility

In [10]:
def train(epoch,net,optimizer,scheduler, criterion,train_loader, val_loader):
    for epoch in tqdm(range(epoch)):  
        train_loss = 0
        correct_train = 0
        total_train = 0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs.to(device))
            loss = criterion(outputs, labels.to(device))
            loss.backward()
            optimizer.step()
            scheduler.step(loss)
            train_loss += loss.item()
            _, predicted_train = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += predicted_train.eq(labels.to(device)).cpu().sum().item()
    
        train_accuracy = 100 * correct_train / total_train
        val_accuracy = evaluate_accuracy(net, val_loader)
        print('Epoch %d, train loss: %.3f, train accuracy: %.2f%%, val accuracy: %.2f%%' %
              (epoch + 1, train_loss / len(train_loader), train_accuracy, val_accuracy))
        
    print('Finished Training')
    
def evaluate_accuracy(model, dataloader):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = model(images.to(device))
            _, predicted = torch.max(outputs.data, 1)  
            total += labels.size(0) 
            correct += (predicted == labels.to(device)).sum().item() 
    accuracy = 100 * correct / total  
    return accuracy

def accuracy_classes(net,dataloader,classes): 
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            outputs = net(images.to(device))
            _, predictions = torch.max(outputs, 1)
            for label, prediction in zip(labels.to(device), predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1

    for classname, correct_count in correct_pred.items():
        accuracy = 100 * float(correct_count) / total_pred[classname]
        print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')

## CNN

In [11]:
class Flatten(nn.Module):
    def forward(self, input):
        batch_size = input.size(0)
        out = input.view(batch_size,-1)
        return out # (batch_size, *size)
    
class CNN(nn.Module):
    def __init__(self, img_shape, hidden_sizes,kernel_size,padding,dropout=False,p=0.05,act='relu',input_channel=1,n_classes=35):
        super(CNN, self).__init__()

        net=[]
        hidden_sizes=[input_channel]+hidden_sizes
        #print(img_shape)
        for i in range(len(hidden_sizes)-1):
            net.append(nn.Conv2d(hidden_sizes[i], hidden_sizes[i+1], kernel_size=kernel_size[i], padding=padding[i]))
            img_shape=[(img_shape[0]-kernel_size[i]+2*padding[i])/1+1,(img_shape[1]-kernel_size[i]+2*padding[i])/1+1 ]
            #print(i, img_shape)
            net.append(nn.BatchNorm2d(hidden_sizes[i+1]))
            if act=='relu':
                net.append(nn.ReLU())
            elif act=='elu':
                net.append(nn.SiLU())
            else:
                net.append(nn.Tanh())
            if dropout:
                net.append(nn.Dropout(p=p))
            net.append(nn.MaxPool2d(kernel_size=2, stride=2,padding=0))
            img_shape=[(img_shape[0]-2)//2+1,(img_shape[1]-2)//2+1]
            #print(i, img_shape)
        net.append(Flatten())
        #print(int(img_shape[0]*img_shape[1]*hidden_sizes[-1]))
        net.append(nn.Linear(int(img_shape[0]*img_shape[1]*hidden_sizes[-1]), n_classes))
        self.net=nn.ModuleList(net)

        
    def forward(self, x):
        for l in self.net:
            #print(x.shape)
            x=l(x)
        return x

In [12]:
hidden_sizes=[128,128]
kernel_size=[7,5]
padding_size=[2,1]

In [13]:
net = CNN(img_shape,hidden_sizes=hidden_sizes,kernel_size=kernel_size, padding=padding_size,n_classes=n_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0005)
scheduler= optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.92, patience=3, min_lr=1e-7)
epoch=5

In [14]:
summary(net, input_size=(1,*img_shape))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 62, 62]           6,400
       BatchNorm2d-2          [-1, 128, 62, 62]             256
              ReLU-3          [-1, 128, 62, 62]               0
         MaxPool2d-4          [-1, 128, 31, 31]               0
            Conv2d-5          [-1, 128, 29, 29]         409,728
       BatchNorm2d-6          [-1, 128, 29, 29]             256
              ReLU-7          [-1, 128, 29, 29]               0
         MaxPool2d-8          [-1, 128, 14, 14]               0
           Flatten-9                [-1, 25088]               0
           Linear-10                   [-1, 22]         551,958
Total params: 968,598
Trainable params: 968,598
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 15.05
Params size (MB): 3.69
Estimated 

In [15]:
train(epoch,net,optimizer,scheduler, criterion,train_loader,val_loader)

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1, train loss: 0.738, train accuracy: 82.05%, val accuracy: 92.00%
Epoch 2, train loss: 0.285, train accuracy: 92.75%, val accuracy: 92.08%
Epoch 3, train loss: 0.273, train accuracy: 93.18%, val accuracy: 92.23%
Epoch 4, train loss: 0.272, train accuracy: 92.99%, val accuracy: 92.38%
Epoch 5, train loss: 0.273, train accuracy: 93.13%, val accuracy: 92.15%
Finished Training


In [16]:
accuracy_classes(net,val_loader,class_names)   

Accuracy for class: Repeating_Blips is 95.9 %
Accuracy for class: Power_Line is 96.2 %
Accuracy for class: Paired_Doves is 96.9 %
Accuracy for class: Tomte is 98.3 %
Accuracy for class: Low_Frequency_Lines is 80.6 %
Accuracy for class: 1400Ripples is 95.1 %
Accuracy for class: Low_Frequency_Burst is 99.4 %
Accuracy for class: Air_Compressor is 92.9 %
Accuracy for class: Wandering_Line is 90.8 %
Accuracy for class: Chirp is 90.2 %
Accuracy for class: Violin_Mode is 87.5 %
Accuracy for class: No_Glitch is 61.4 %
Accuracy for class: Light_Modulation is 46.2 %
Accuracy for class: Helix is 12.5 %
Accuracy for class: 1080Lines is 100.0 %
Accuracy for class: Scattered_Light is 79.3 %
Accuracy for class: Koi_Fish is 96.7 %
Accuracy for class: Scratchy is 98.5 %
Accuracy for class: None_of_the_Above is 80.9 %
Accuracy for class: Extremely_Loud is 94.9 %
Accuracy for class: Blip  is 45.8 %
Accuracy for class: Whistle is 81.0 %
