Skip to content

Commit

Permalink
Improved fusion model by simplifying model
Browse files Browse the repository at this point in the history
  • Loading branch information
CodingTil committed Oct 30, 2023
1 parent 61c9543 commit ed2a176
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 19 deletions.
54 changes: 35 additions & 19 deletions eiuie/fusion_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,42 @@
CHECKPOINT_DIRECTORY = "data/checkpoints"


class ChannelNet(nn.Module):
"""Single layer perceptron for individual channels."""

def __init__(self, input_size=4, output_size=1):
super(ChannelNet, self).__init__()
self.fc = nn.Linear(input_size, output_size)

def forward(self, x):
return self.fc(x)


class FusionNet(nn.Module):
def __init__(self, dropout_rate=0.5):
super(FusionNet, self).__init__()
"""Unifying model for all channels."""

self.model = nn.Sequential(
nn.Linear(12, 12),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(12, 9),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(9, 6),
nn.ReLU(),
nn.Dropout(dropout_rate),
nn.Linear(6, 3),
nn.ReLU(),
nn.Linear(3, 3),
nn.Sigmoid(),
)
def __init__(self):
super(FusionNet, self).__init__()
self.h_net = ChannelNet()
self.s_net = ChannelNet()
self.i_net = ChannelNet()

def forward(self, x):
return self.model(x)
# Flatten the middle dimensions
x = x.view(-1, 12) # This will reshape the input to (batch_size, 12)

# Splitting the input for the three channels
h_channel = x[:, 0::3] # Every third value starting from index 0
s_channel = x[:, 1::3] # Every third value starting from index 1
i_channel = x[:, 2::3] # Every third value starting from index 2

# Getting the outputs
h_out = self.h_net(h_channel)
s_out = self.s_net(s_channel)
i_out = self.i_net(i_channel)

# Concatenate the outputs to get the final output
return torch.cat((h_out, s_out, i_out), dim=1)


class EarlyStopping:
Expand Down Expand Up @@ -142,7 +156,9 @@ def save_checkpoint(self, epoch: int, checkpoint_path: str):
)

def load_checkpoint(self, checkpoint_path: str):
checkpoint = torch.load(f"{CHECKPOINT_DIRECTORY}/{checkpoint_path}")
checkpoint = torch.load(
f"{CHECKPOINT_DIRECTORY}/{checkpoint_path}", map_location=self.device
)
self.net.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.start_epoch = checkpoint["epoch"]
Expand Down
6 changes: 6 additions & 0 deletions eiuie/pixel_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def __init__(self, batch_size=1, chunk_size=10000, use_fraction=1.0):

self.data_array[start_idx:end_idx] = np.concatenate(hsi_data_list, axis=1)

# Shuffle data_array
np.random.shuffle(self.data_array)

self.batch_size = batch_size

def __len__(self) -> int:
Expand All @@ -49,6 +52,9 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:

batch_data = self.data_array[start:end]

# Shuffle batch_data
np.random.shuffle(batch_data)

inputs = torch.tensor(batch_data[:, :12], dtype=torch.float32)
outputs = torch.tensor(batch_data[:, 12:], dtype=torch.float32)

Expand Down

0 comments on commit ed2a176

Please sign in to comment.