In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import numpy as np
import hexagdly

In [38]:
class hex_model(nn.Module):
    def __init__(self, nin, nout):          
        super(hex_model, self).__init__()
        self.name = 'hex_model'
        self.hexconv_1 = hexagdly.Conv2d(in_channels = nin, out_channels = 4, \
                                         kernel_size = 1, stride = 1, bias=True)
        self.hexpool_1 = hexagdly.MaxPool2d(kernel_size = 1, stride = 2)
        self.hexconv_2 = hexagdly.Conv2d(4, 8, 2, 1, bias=True)
        self.hexpool_2 = hexagdl y.MaxPool2d(kernel_size = 2, stride = 2)
        self.bn1 = nn.BatchNorm2d(4)
        self.bn2 = nn.BatchNorm2d(8)
        
        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, nout)

    def forward(self, x):
        # Applying hexagonal convolutional and pooling layers
        print("Initial Size:", x.size())
        x = self.hexconv_1(x)
        print("After conv1:", x.size())
        x = F.relu(x)
        x = F.relu(self.bn1(x))
        print("After bn1:", x.size())
        x = self.hexpool_1(x)
        print("After maxpool1:", x.size())
        
        x = self.hexconv_2(x)
        print("After conv2:", x.size())
        x = F.relu(x)
        x = F.relu(self.bn2(x))
        print("After bn2:", x.size())
        # x = self.hexpool_2(x)
        
        # Flatten the output for the fully connected layers
        # x = torch.flatten(x, 1)
        x = x.view(-1, 512)
        print("Before fc1:", x.size())
        
        # Applying fully connected layers with dropout and activation functions
        x = F.relu(self.fc1(x))
        print("After fc1:", x.size())
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        print("After fc2:", x.size())
        x = self.dropout(x)
        x = self.fc3(x)  # Output layer, no activation here
        print("After fc3:", x.size())
        
        return x



# class fc(nn.Module):
#     def __init__(self, in_features, out_features):
#         super(fc, self).__init__()
#         self.ops = nn.Sequential(nn.Linear(in_features, out_features),
#                                  nn.ReLU())
#     def forward(self, input):
#         return self.ops(input)


# class hex_model(nn.Module):
#     def __init__(self, nout, plot=False):
#         super(hex_model, self).__init__()
#         self.name = 'hex_model'
#         self.plot=plot
#         self.hexconv_1 = hexagdly.Conv2d(in_channels = 1, out_channels = 4, \
#                                          kernel_size = 1, stride = 1, bias=True)
#         self.hexpool = hexagdly.MaxPool2d(kernel_size = 1, stride = 2)
#         self.hexconv_2 = hexagdly.Conv2d(4, 8, 1, 1, bias=True)

#         self.fc = nn.Sequential(fc(512, 128),
#                                 nn.Dropout(0.5),
#                                 fc(128, 64),
#                                 nn.Linear(64, nout))

#     def forward(self, x, ni=0):
#         x = self.hexconv_1(x)
        
#         # Option to plot first conv layer output
#         if self.plot:
#             plot_hextensor(x, image_range=(ni,ni+1), channel_range=(0,None), figname='after_first_hexconv')
            
#         x = F.relu(x)
#         x = self.hexpool(x)
#         x = self.hexconv_2(x)
        
#         # Option to plot second conv layer output
#         if self.plot:
#             plot_hextensor(x, image_range=(ni,ni+1), channel_range=(0,None), figname='after_second_hexconv')
            
#         x = F.relu(x)
#         x = x.view(-1, 512)
#         x = self.fc(x)
#         return x

In [23]:
import importlib
import numpy as np
import matplotlib.pyplot as plt
from hexagdly_tools import plot_hextensor
from example_utils import toy_data, toy_dataset, model

In [24]:
shape_list = ['snowflake_2', 'snowflake_3', 'snowflake_4', 'double_hex']
val_data = toy_dataset(shape_list, 32)
val_data.create()
val_dataloader = val_data.to_dataloader()
train_data = toy_dataset(shape_list, 128)
train_data.create()
train_dataloader = train_data.to_dataloader()

In [37]:
net = hex_model(1,len(shape_list))
epochs = 10
cnn_model = model(train_dataloader, val_dataloader, net, epochs=epochs)
cnn_model.train()
cnn_model.save_current()

Epoch 1
Initial Size: torch.Size([8, 1, 16, 16])
After conv1: torch.Size([8, 4, 16, 16])
After bn1: torch.Size([8, 4, 16, 16])
After maxpool1: torch.Size([8, 4, 8, 8])
After conv2: torch.Size([8, 8, 8, 8])
After bn2: torch.Size([8, 8, 8, 8])
Before fc1: torch.Size([8, 512])
After fc1: torch.Size([8, 256])
After fc2: torch.Size([8, 64])
After fc3: torch.Size([8, 4])
Initial Size: torch.Size([8, 1, 16, 16])
After conv1: torch.Size([8, 4, 16, 16])
After bn1: torch.Size([8, 4, 16, 16])
After maxpool1: torch.Size([8, 4, 8, 8])
After conv2: torch.Size([8, 8, 8, 8])
After bn2: torch.Size([8, 8, 8, 8])
Before fc1: torch.Size([8, 512])
After fc1: torch.Size([8, 256])
After fc2: torch.Size([8, 64])
After fc3: torch.Size([8, 4])
Initial Size: torch.Size([8, 1, 16, 16])
After conv1: torch.Size([8, 4, 16, 16])
After bn1: torch.Size([8, 4, 16, 16])
After maxpool1: torch.Size([8, 4, 8, 8])
After conv2: torch.Size([8, 8, 8, 8])
After bn2: torch.Size([8, 8, 8, 8])
Before fc1: torch.Size([8, 512])
After

KeyboardInterrupt: 

In [None]:
def train_model(model, train_loader, optimizer, criterion, epoch):
    model.train()
    train_loss = 0
    for input, target in tqdm(train_loader, total=len(train_loader)):
        optimizer.zero_grad()
        output, _ = model(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)
    print('[Training set] Epoch: {:d}, Average loss: {:.4f}'.format(epoch + 1, train_loss))
   