In [26]:
import random, torch, torch.nn.functional as F
from torch.nn import Linear, Sequential, BatchNorm1d, ReLU
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.data import Data
from gnn_dataset_generation import PolymerDataset

class GIN(torch.nn.Module):
    """Graph-level regressor built from three GIN convolutions.

    Node features pass through three ``GINConv`` layers, are sum-pooled per
    graph, and are then mapped to one scalar per graph by a two-layer head.

    Args:
        in_dim: Width of the input node-feature vectors.
        hidden_dim: Width of every hidden representation.
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()

        def build_mlp(n_in):
            # The layer order fixes the Sequential indices, and therefore the
            # state-dict keys a saved checkpoint is matched against.
            return Sequential(
                Linear(n_in, hidden_dim),
                BatchNorm1d(hidden_dim),
                ReLU(),
                Linear(hidden_dim, hidden_dim),
                ReLU(),
            )

        self.conv1 = GINConv(build_mlp(in_dim))
        self.conv2 = GINConv(build_mlp(hidden_dim))
        self.conv3 = GINConv(build_mlp(hidden_dim))
        self.lin1 = Linear(hidden_dim, hidden_dim)
        self.lin2 = Linear(hidden_dim, 1)

    def forward(self, d):
        """Return a 1-D tensor holding one prediction per graph in ``d``.

        ``d`` must expose ``x``, ``edge_index`` and ``batch``.
        """
        edges = d.edge_index
        h = F.relu(self.conv1(d.x, edges))
        h = F.relu(self.conv2(h, edges))
        h = self.conv3(h, edges)
        # Sum node embeddings within each graph to get a graph-level vector.
        pooled = global_add_pool(h, d.batch)
        hidden = F.relu(self.lin1(pooled))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.lin2(hidden).view(-1)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = PolymerDataset(root='.')
# Gather column 0 of every graph's target tensor; no gradients are needed here.
with torch.no_grad():
    y_raw  = torch.stack([d.y for d in dataset])[:, 0]
# Mean / standard deviation of that target column; fitness() uses them to map
# the model output back to the original target scale.
# NOTE(review): presumably these match the normalisation used when the
# checkpoint was trained -- confirm against the training code.
mu, sigma = y_raw.mean().item(), y_raw.std().item()
# Input width is read from the first graph's node-feature matrix; hidden width is 128.
model = GIN(dataset[0].x.size(1), 128).to(device)
# Restore the saved weights, remapping tensors onto the selected device.
model.load_state_dict(torch.load('best_gin_area.pt', map_location=device))
model.eval()  # inference mode: self.training is False, so the dropout in forward() is inactive

def rand_gene():
    """Draw one gene: a uniformly random integer in [-10, 10], both ends included."""
    gene = random.randint(-10, 10)
    return gene
def token(n):
    """Render a gene as a text token.

    Positive values become ``"S<n>"``; zero and negative values become
    ``"E<|n|>"`` (so ``0`` is rendered as ``"E0"``).
    """
    if n > 0:
        return f"S{n}"
    return f"E{abs(n)}"

def encode(seq):
    """Turn a gene vector into a single-graph ``Data`` object on ``device``.

    Every gene contributes one backbone node with features ``[1, 0, 0]``;
    consecutive backbone nodes are linked. A non-zero gene ``v`` additionally
    hangs a linear chain of ``abs(v)`` nodes off its backbone node, with
    features ``[0, 1, 0]`` when ``v > 0`` and ``[0, 0, 1]`` when ``v < 0``.
    Every link is stored in both directions, and all nodes are assigned to
    batch index 0.

    Args:
        seq: Sequence of integer genes.

    Returns:
        A ``Data`` instance with ``x``, ``edge_index`` and ``batch`` set,
        moved to the module-level ``device``.
    """
    n_backbone = len(seq)

    # Backbone nodes occupy indices 0 .. n_backbone - 1.
    node_feats = [[1, 0, 0] for _ in range(n_backbone)]

    src, dst = [], []
    for left in range(n_backbone - 1):
        right = left + 1
        src.extend((left, right))
        dst.extend((right, left))

    # Side-chain nodes are numbered after the backbone, in order of creation.
    for anchor, gene in enumerate(seq):
        if gene == 0:
            continue
        side_feat = [0, 1, 0] if gene > 0 else [0, 0, 1]
        tail = anchor
        for _ in range(abs(gene)):
            new_node = len(node_feats)
            node_feats.append(list(side_feat))
            src.extend((tail, new_node))
            dst.extend((new_node, tail))
            tail = new_node

    graph = Data(
        x=torch.tensor(node_feats, dtype=torch.float),
        edge_index=torch.tensor([src, dst], dtype=torch.long),
    )
    # A single graph: every node belongs to batch element 0.
    graph.batch = torch.zeros(graph.num_nodes, dtype=torch.long)
    return graph.to(device)

@torch.no_grad()
def fitness(seq):
    """Score a gene vector with the loaded model.

    The sequence is encoded as a graph, run through ``model``, and the scalar
    output is rescaled with the module-level ``sigma`` and ``mu``.
    """
    graph = encode(seq)
    normalised = model(graph).item()
    return normalised * sigma + mu

def mutate(seq, *, gene_rate=0.1, insert_rate=0.05, delete_rate=0.05,
           min_len=10, max_len=20):
    """Return a mutated copy of ``seq``; the input list is never modified.

    Three independent steps are applied in order:

    1. Each gene is replaced by a fresh ``rand_gene()`` with probability
       ``gene_rate``.
    2. With probability ``insert_rate``, and only if the vector is shorter
       than ``max_len``, one random gene is inserted at a random position.
    3. With probability ``delete_rate``, and only if the vector is longer
       than ``min_len``, one randomly chosen gene is removed.

    Args:
        seq: Gene vector to mutate.
        gene_rate: Per-gene replacement probability.
        insert_rate: Probability of attempting an insertion.
        delete_rate: Probability of attempting a deletion.
        min_len: Deletion never shrinks the vector below this length.
        max_len: Insertion never grows the vector beyond this length.

    Returns:
        A new list holding the mutated genes.
    """
    mutated = [
        rand_gene() if random.random() < gene_rate else gene
        for gene in seq
    ]

    # The random draw comes before the length check so the RNG stream does
    # not depend on the current vector length.
    if random.random() < insert_rate and len(mutated) < max_len:
        mutated.insert(random.randrange(len(mutated) + 1), rand_gene())

    # max(min_len, 0) keeps randrange() away from an empty list even when a
    # caller passes a negative lower bound.
    if random.random() < delete_rate and len(mutated) > max(min_len, 0):
        mutated.pop(random.randrange(len(mutated)))

    return mutated

def crossover(a, b):
    c1, c2 = random.randrange(1, len(a)), random.randrange(1, len(b))
    child = a[:c1] + b[c2:]
    if len(child) > 20: child = child[:20]
    if len(child) < 10: child += [rand_gene() for _ in range(10 - len(child))]
    return child

# --- Genetic search: minimise the model's predicted value -------------------
# Start from 100 random gene vectors, each 10-20 genes long.
population = []
for _ in range(100):
    length = random.randint(10, 20)
    population.append([rand_gene() for _ in range(length)])

for generation in range(50):
    # Ascending sort: the lowest predicted value ranks first.
    population.sort(key=fitness)

    # Elitism: the ten best individuals are carried over unchanged.
    elites = population[:10]
    next_pop = list(elites)

    # Only the better half of the population is allowed to reproduce.
    breeding_pool = population[:50]
    while len(next_pop) < 100:
        if random.random() < 0.8:
            parent_a, parent_b = random.sample(breeding_pool, 2)
            child = crossover(parent_a, parent_b)
        else:
            child = breeding_pool[random.randrange(50)].copy()
        next_pop.append(mutate(child))

    population = next_pop

# The final generation is unsorted, so the winner is picked explicitly.
# NOTE(review): the recorded run prints a negative predicted area, so the
# search may be pushing the model outside the range it was trained on --
# confirm before trusting the result.
best = min(population, key=fitness)
best_tokens = list(map(token, best))
print("Best vector:", best)
print("Predicted area:", fitness(best))
print("Tokens:", best_tokens)

Best vector: [-1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
Predicted area: -1054.6394931082614
Tokens: ['E1', 'E0', 'E1', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'S1', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0', 'E0']
