In [1]:
import torch
import torch.nn as nn
from torch.nn import init
# Prefer the GPU when one is present; the model and batches are moved to `device` later.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
# Seed torch's global RNG so weight init and DataLoader shuffling are repeatable.
torch.manual_seed(0)

<torch._C.Generator at 0x2936d979d10>

In [5]:
# Load the raw DNA sequence as one string. The `with` block closes the file
# when it exits, so no explicit close() call is needed afterwards.
with open('AeCa.txt', 'r', encoding='UTF-8') as f:
    text = f.read()

In [6]:
# Alphabet of distinct symbols in the sequence, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(chars, vocab_size)

['A', 'C', 'G', 'T'] 4


In [7]:
# Symbol -> integer id, and the inverse table id -> symbol.
stoi = {ch: idx for idx, ch in enumerate(chars)}
itos = {idx: ch for ch, idx in stoi.items()}
stoi, itos

({'A': 0, 'C': 1, 'G': 2, 'T': 3}, {0: 'A', 1: 'C', 2: 'G', 3: 'T'})

In [8]:
def encode(word):
    """Map a string over the alphabet to a list of integer ids via `stoi`."""
    return [stoi[s] for s in word]


def decode(num):
    """Map a sequence of integer ids back to a string via `itos`.

    Accepts plain ints as well as tensor/numpy elements. Each id is converted with
    int(): torch tensors hash by identity, so looking up `itos[tensor(2)]`
    directly would raise KeyError (e.g. when decoding `data` or model predictions).
    """
    return ''.join(itos[int(i)] for i in num)


data = torch.tensor(encode(text), dtype=torch.long)
data

tensor([2, 1, 1,  ..., 2, 2, 1])

In [9]:
first_id = data[0]  # 0-d tensor holding the id of the first symbol
first_id

tensor(2)

In [10]:
def cut_sequence(data, window_size):
  """Slice a sequence into fixed-length context windows and next-element targets.

  Args:
    data: any sliceable, indexable sequence (list, 1-D tensor, 1-D array).
    window_size: number of elements per window; must be >= 1.

  Returns:
    (X, Y) lists where X[i] = data[i:i+window_size] and Y[i] = data[i+window_size].
    Both lists have len(data) - window_size entries, and are empty when
    len(data) <= window_size.

  Raises:
    ValueError: if window_size < 1. A zero or negative size would otherwise
      yield empty windows or out-of-range target lookups.
  """
  if window_size < 1:
    raise ValueError(f"window_size must be a positive integer, got {window_size}")
  X, Y = [], []
  for start in range(len(data) - window_size):
    end = start + window_size
    X.append(data[start:end])
    # The element right after the window is the prediction target.
    Y.append(data[end])
  return X, Y

window_size = 64  # context length (symbols) per training example
X_list, Y_list = cut_sequence(data, window_size)

# X_list holds (window_size,) tensor slices and Y_list 0-d tensors. torch.stack
# builds each result in one vectorised call; torch.tensor(Y_list) would instead
# convert the ~1.6M scalar tensors one by one. Values and dtype are identical.
X = torch.stack(X_list)
Y = torch.stack(Y_list)

print(X.shape)
print(Y.shape)

torch.Size([1590985, 64])
torch.Size([1590985])


In [11]:
import numpy as np

# Integer id -> one-hot vector over the 4-symbol alphabet (identity rows).
one_hot_mapping = {
    code: [int(code == col) for col in range(4)]
    for code in range(4)
}
def OneHot(arr, one_hot_mapping):
  """Reference (slow, per-element) one-hot encoder.

  Looks up each element's `.item()` in `one_hot_mapping`. A 2-D input is encoded
  row by row; any other input is iterated as a flat sequence. Returns an np.ndarray.
  """
  def encode_row(row):
    return [one_hot_mapping[val.item()] for val in row]

  if arr.ndim != 2:
    return np.array(encode_row(arr))
  return np.array([encode_row(row) for row in arr])

In [13]:
import numpy as np

def optimized_OneHot(arr, one_hot_mapping):
  """Vectorised encoding of an integer-coded array through `one_hot_mapping`.

  Args:
    arr: np.ndarray of integer codes; every code must be a key of the mapping.
    one_hot_mapping: dict code -> encoding vector. Keys must be 0..K-1 and all
      vectors must have the same length D.

  Returns:
    float32 np.ndarray. Inputs of rank >= 2 give shape arr.shape + (D,); 0-d and
    1-d inputs give shape (arr.size, D).

  Raises:
    KeyError: if the mapping keys are not exactly 0..K-1.
    IndexError: if a code is >= K (or arr is not integer-typed).
  """
  # Row k of the table is the encoding of code k, so a single fancy-index gather
  # applies the mapping's actual values (not an assumed identity one-hot).
  table = np.asarray(
    [one_hot_mapping[k] for k in range(len(one_hot_mapping))], dtype=np.float32
  )
  one_hot = table[arr.reshape(-1)]
  if arr.ndim <= 1:
    return one_hot
  # Restore all leading dimensions for any rank >= 2 input.
  return one_hot.reshape(*arr.shape, table.shape[1])

In [15]:
# One-hot encode inputs and targets as float32 tensors.
# torch.from_numpy wraps the freshly built arrays without copying. torch.tensor
# would duplicate them, and X alone is ~1.6 GB (1590985 x 64 x 4 float32).
# NOTE: this rebinds X and Y, so re-running the cell on its own output would
# re-encode already-encoded data. Re-run the windowing cell first.
X = torch.from_numpy(optimized_OneHot(X.numpy(), one_hot_mapping))
Y = torch.from_numpy(optimized_OneHot(Y.numpy(), one_hot_mapping))

In [16]:
# Sanity check: windows are (N, window, 4) and targets (N, 4).
print(X.shape, Y.shape, sep='\n')

torch.Size([1590985, 64, 4])
torch.Size([1590985, 4])


In [17]:
import torch.nn.functional as F
class DNACoder(nn.Module):
    """CNN -> BiLSTM -> attention classifier over a window of encoded symbols.

    Input: float tensor of shape (batch, seq_len, input_dim), e.g. one-hot windows.
    Output: unnormalised logits of shape (batch, output_dim) for the symbol
    that follows the window.

    Args:
        input_dim: channels per position (size of the one-hot alphabet).
        window_size: not used by the layers; kept so existing calls keep working.
            Note that seq_len must be >= 6 for the two conv/pool stages to leave
            at least one step.
        hidden_dim: LSTM hidden size per direction. The attention width is 2 * hidden_dim.
        num_heads: attention heads; 2 * hidden_dim must be divisible by it.
        fc_dim: width of the hidden fully connected layer.
        output_dim: number of output classes (default 4).
        dropout: dropout probability applied before the output layer.
    """

    def __init__(self, input_dim=4, window_size=64, hidden_dim=256, num_heads=8, fc_dim=128,
                 output_dim=4, dropout=0.3):
        super().__init__()

        # Stage 1: unpadded conv (seq_len -> seq_len - 2), then halve with max-pool.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=64, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(num_features=64)
        self.pool1 = nn.MaxPool1d(kernel_size=2)

        # Stage 2: length-preserving conv (padding=1), then halve again.
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(num_features=128)
        self.pool2 = nn.MaxPool1d(kernel_size=2)

        self.lstm = nn.LSTM(input_size=128, hidden_size=hidden_dim, batch_first=True, bidirectional=True)

        self.attn = nn.MultiheadAttention(embed_dim=hidden_dim*2, num_heads=num_heads, batch_first=True)

        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(hidden_dim*2, fc_dim)
        self.fc_out = nn.Linear(fc_dim, output_dim)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B, input_dim, seq_len): Conv1d wants channels first
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))  # (B, 64, (seq_len - 2) // 2)
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))  # (B, 128, previous // 2)
        x = x.permute(0, 2, 1)  # (B, steps, 128) for the batch_first LSTM

        lstm_out, _ = self.lstm(x)  # (B, steps, 2 * hidden_dim)
        # Only the last step's attention output is used. Each query row is computed
        # independently, so querying with the last step alone gives the same result
        # as the last row of full self-attention, without computing the unused rows
        # or the averaged weights.
        last_step = lstm_out[:, -1:, :]
        attn_out, _ = self.attn(last_step, lstm_out, lstm_out, need_weights=False)
        x = self.dropout(F.relu(self.fc1(attn_out[:, -1, :])))
        return self.fc_out(x)


In [18]:
from torch.utils.data import TensorDataset, DataLoader

BATCH_SIZE = 256

# Y is stored one-hot; CrossEntropyLoss expects class indices, so recover them with argmax.
dataset = TensorDataset(X, Y.argmax(dim=1))

# Shuffled mini-batches. Pinned host memory only speeds up host->GPU copies; on a
# CPU-only machine it just triggers a warning, so enable it only for CUDA.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=(device == 'cuda'))

In [19]:
# Peek at one mini-batch to confirm input and target shapes.
xb, yb = next(iter(loader))
print(xb.shape, yb.shape, sep='\n')

torch.Size([256, 64, 4])
torch.Size([256])


In [20]:
# Quick smoke test: push one batch through a freshly initialised CPU model and
# compare the logits' shape with the targets'.
model = DNACoder(input_dim=4, window_size=window_size)
outputs = model(xb.float())
preds = outputs.argmax(dim=1)
for item in (outputs.shape, outputs, yb.shape, yb, preds.shape):
    print(item)

torch.Size([256, 4])
tensor([[ 0.0501,  0.0093,  0.0207, -0.0636],
        [ 0.0664, -0.0013,  0.0329, -0.0765],
        [ 0.0625,  0.0025,  0.0222, -0.0670],
        ...,
        [ 0.0497, -0.0208,  0.0086, -0.0436],
        [ 0.0487, -0.0303,  0.0391, -0.0576],
        [ 0.0688, -0.0139,  0.0259, -0.0668]], grad_fn=<AddmmBackward0>)
torch.Size([256])
tensor([2, 1, 3, 2, 1, 1, 2, 0, 0, 1, 2, 2, 2, 2, 0, 1, 1, 0, 1, 1, 3, 2, 3, 1,
        2, 3, 2, 1, 1, 1, 3, 1, 3, 2, 3, 0, 3, 2, 3, 2, 1, 3, 1, 2, 1, 2, 0, 0,
        0, 3, 1, 2, 0, 1, 3, 2, 1, 1, 0, 2, 1, 2, 2, 1, 1, 1, 2, 0, 1, 2, 2, 0,
        2, 3, 0, 2, 2, 3, 1, 1, 0, 2, 1, 2, 1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 2, 1,
        1, 3, 1, 3, 2, 2, 3, 2, 2, 2, 2, 3, 2, 3, 0, 3, 0, 2, 2, 0, 1, 1, 0, 0,
        0, 2, 2, 3, 2, 3, 2, 0, 3, 3, 0, 3, 1, 2, 0, 2, 3, 1, 3, 2, 0, 1, 0, 3,
        2, 0, 1, 3, 1, 2, 2, 1, 1, 3, 2, 0, 0, 1, 1, 2, 0, 0, 1, 0, 1, 0, 0, 2,
        0, 3, 2, 3, 0, 2, 3, 1, 2, 0, 1, 1, 3, 2, 0, 3, 1, 1, 3, 2, 3, 2, 1, 3,
     

In [21]:
NUM_EPOCHS = 40
LOG_EVERY = 5

model = DNACoder(input_dim=4, window_size=window_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCHS):
    total_loss = 0.0
    correct = 0
    total = 0

    model.train()
    for xb, yb in loader:
        # non_blocking lets the copy overlap with compute when the loader pins memory.
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        optimizer.zero_grad()

        outputs = model(xb.float())
        loss = criterion(outputs, yb)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean. Weight it by batch size so the smaller
        # final batch does not skew the epoch average.
        n_batch = yb.size(0)
        total_loss += loss.item() * n_batch
        preds = torch.argmax(outputs, dim=1)   # predicted class index
        correct += (preds == yb).sum().item()
        total += n_batch

    epoch_loss = total_loss / total
    # Training-set accuracy, measured with dropout active (there is no held-out split).
    epoch_acc = correct / total * 100
    # Log the first epoch, every LOG_EVERY-th epoch, and always the final one.
    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")

KeyboardInterrupt: 

In [None]:
# Persist only the learned parameters (state_dict), not the whole pickled module.
torch.save(obj=model.state_dict(), f='model.pth')

In [None]:
# Rebuild the architecture, then restore the saved weights into it.
loaded_model = DNACoder(input_dim=4, window_size=window_size).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all a state_dict needs, so a tampered model.pth cannot run arbitrary code.
state_dict = torch.load('model.pth', map_location=device, weights_only=True)
loaded_model.load_state_dict(state_dict)
loaded_model.eval()  # inference mode: BatchNorm uses running stats, dropout is disabled

print("Model loaded successfully!")

Model loaded successfully!
