diff --git a/pom.xml b/pom.xml
index 20ac713..b96555f 100644
--- a/pom.xml
+++ b/pom.xml
@@ -67,6 +67,13 @@
provided
+
+ org.keycloak
+ keycloak-model-jpa
+ ${keycloak.version}
+ provided
+
+
org.junit.jupiter
junit-jupiter-engine
@@ -114,5 +121,10 @@
1.18.26
compile
+
+ jakarta.persistence
+ jakarta.persistence-api
+ 3.1.0
+
diff --git a/src/main/java/org/keycloak/social/weixin/WeiXinIdentityProvider.java b/src/main/java/org/keycloak/social/weixin/WeiXinIdentityProvider.java
index 3d3a0a6..5aaf3c3 100644
--- a/src/main/java/org/keycloak/social/weixin/WeiXinIdentityProvider.java
+++ b/src/main/java/org/keycloak/social/weixin/WeiXinIdentityProvider.java
@@ -244,13 +244,14 @@ protected UriBuilder createAuthorizationUrl(AuthenticationRequest request) {
var wechatApi = new WechatMpApi(
config.getConfig().get(WECHAT_MP_APP_ID),
config.getConfig().get(WECHAT_MP_APP_SECRET),
- session
+ session,
+ request.getAuthenticationSession()
);
- var ticketUrl = wechatApi.createTmpQrCode(new TicketRequest(2592000, "QR_STR_SCENE", new ActionInfo(new Scene("1")))).url;
- logger.info("ticketUrl = " + ticketUrl);
+ var ticket = wechatApi.createTmpQrCode(new TicketRequest(2592000, "QR_STR_SCENE", new ActionInfo(new Scene("1")))).ticket;
+ logger.info("ticket = " + ticket);
- uriBuilder.queryParam("ticket-url", ticketUrl);
+ uriBuilder.queryParam("ticket", ticket).queryParam("qr-code-url", "https://mp.weixin.qq.com/cgi-bin/showqrcode?ticket=" + ticket);
}
} else {
uriBuilder = UriBuilder.fromUri(config.getAuthorizationUrl());
diff --git a/src/main/java/org/keycloak/social/weixin/cache/TicketEntity.java b/src/main/java/org/keycloak/social/weixin/cache/TicketEntity.java
new file mode 100644
index 0000000..3eb9a4c
--- /dev/null
+++ b/src/main/java/org/keycloak/social/weixin/cache/TicketEntity.java
@@ -0,0 +1,32 @@
+package org.keycloak.social.weixin.cache;
+
+import jakarta.persistence.Id;
+import jakarta.persistence.NamedQueries;
+import jakarta.persistence.NamedQuery;
+import lombok.Getter;
+import lombok.Setter;
+import jakarta.persistence.Entity;
+
+@NamedQueries({
+ @NamedQuery(name = "TicketEntity.findById", query = "select t from TicketEntity t where t.id = :id"),
+ @NamedQuery(name = "TicketEntity.findByTicket", query = "select t from TicketEntity t where t.ticket = :ticket"),
+})
+@Getter
+@Entity
+public class TicketEntity {
+ @Setter
+ @Id
+ private String id;
+ @Setter
+ private String ticket;
+ @Setter
+ private String status;
+ @Setter
+ private Number expireSeconds;
+ @Setter
+ private Number ticketCreatedAt;
+ @Setter
+ private Number scannedAt;
+ @Setter
+ private String openid;
+}
diff --git a/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProvider.java b/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProvider.java
new file mode 100644
index 0000000..0e6c3f9
--- /dev/null
+++ b/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProvider.java
@@ -0,0 +1,338 @@
+package org.keycloak.social.weixin.cache;
+
+import jakarta.persistence.*;
+import jakarta.persistence.criteria.CriteriaBuilder;
+import jakarta.persistence.criteria.CriteriaDelete;
+import jakarta.persistence.criteria.CriteriaQuery;
+import jakarta.persistence.criteria.CriteriaUpdate;
+import jakarta.persistence.metamodel.Metamodel;
+import org.jboss.logging.Logger;
+import org.keycloak.component.ComponentModel;
+import org.keycloak.models.KeycloakSession;
+import org.keycloak.storage.UserStorageProvider;
+import org.keycloak.connections.jpa.JpaConnectionProvider;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.UUID;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class TicketStatusProvider implements UserStorageProvider {
+ private final KeycloakSession session;
+ private final ComponentModel model;
+ private static Map localCache = new ConcurrentHashMap<>();
+
+ private static final Logger logger = Logger.getLogger(TicketStatusProvider.class);
+
+ protected EntityManager em;
+
+ public TicketStatusProvider(KeycloakSession keycloakSession, ComponentModel componentModel) {
+ this.session = keycloakSession;
+ this.model = componentModel;
+ var jpaProvider = session.getProvider(JpaConnectionProvider.class, "ticket-store");
+
+ if (jpaProvider != null) {
+ this.em = jpaProvider.getEntityManager();
+ return;
+ }
+ logger.warn("em is null");
+
+ this.em = new EntityManager() {
+ @Override
+ public void persist(Object o) {
+ localCache.put(((TicketEntity) o).getTicket(), (TicketEntity) o);
+ }
+
+ @Override
+ public T merge(T t) {
+ return null;
+ }
+
+ @Override
+ public void remove(Object o) {
+ localCache.remove(((TicketEntity) o).getTicket());
+ }
+
+ @Override
+ public T find(Class aClass, Object o) {
+ return (T) localCache.get((String) o);
+ }
+
+ @Override
+ public T find(Class aClass, Object o, Map map) {
+ return null;
+ }
+
+ @Override
+ public T find(Class aClass, Object o, LockModeType lockModeType) {
+ return null;
+ }
+
+ @Override
+ public T find(Class aClass, Object o, LockModeType lockModeType, Map map) {
+ return null;
+ }
+
+ @Override
+ public T getReference(Class aClass, Object o) {
+ return (T) o;
+ }
+
+ @Override
+ public void flush() {
+ }
+
+ @Override
+ public void setFlushMode(FlushModeType flushModeType) {
+
+ }
+
+ @Override
+ public FlushModeType getFlushMode() {
+ return null;
+ }
+
+ @Override
+ public void lock(Object o, LockModeType lockModeType) {
+
+ }
+
+ @Override
+ public void lock(Object o, LockModeType lockModeType, Map map) {
+
+ }
+
+ @Override
+ public void refresh(Object o) {
+
+ }
+
+ @Override
+ public void refresh(Object o, Map map) {
+
+ }
+
+ @Override
+ public void refresh(Object o, LockModeType lockModeType) {
+
+ }
+
+ @Override
+ public void refresh(Object o, LockModeType lockModeType, Map map) {
+
+ }
+
+ @Override
+ public void clear() {
+ localCache.clear();
+ }
+
+ @Override
+ public void detach(Object o) {
+
+ }
+
+ @Override
+ public boolean contains(Object o) {
+ return false;
+ }
+
+ @Override
+ public LockModeType getLockMode(Object o) {
+ return null;
+ }
+
+ @Override
+ public void setProperty(String s, Object o) {
+
+ }
+
+ @Override
+ public Map getProperties() {
+ return null;
+ }
+
+ @Override
+ public Query createQuery(String s) {
+ return null;
+ }
+
+ @Override
+ public TypedQuery createQuery(CriteriaQuery criteriaQuery) {
+ return null;
+ }
+
+ @Override
+ public Query createQuery(CriteriaUpdate criteriaUpdate) {
+ return null;
+ }
+
+ @Override
+ public Query createQuery(CriteriaDelete criteriaDelete) {
+ return null;
+ }
+
+ @Override
+ public TypedQuery createQuery(String s, Class aClass) {
+ return null;
+ }
+
+ @Override
+ public Query createNamedQuery(String s) {
+ return null;
+ }
+
+ @Override
+ public TypedQuery createNamedQuery(String s, Class aClass) {
+ return null;
+ }
+
+ @Override
+ public Query createNativeQuery(String s) {
+ return null;
+ }
+
+ @Override
+ public Query createNativeQuery(String s, Class aClass) {
+ return null;
+ }
+
+ @Override
+ public Query createNativeQuery(String s, String s1) {
+ return null;
+ }
+
+ @Override
+ public StoredProcedureQuery createNamedStoredProcedureQuery(String s) {
+ return null;
+ }
+
+ @Override
+ public StoredProcedureQuery createStoredProcedureQuery(String s) {
+ return null;
+ }
+
+ @Override
+ public StoredProcedureQuery createStoredProcedureQuery(String s, Class... classes) {
+ return null;
+ }
+
+ @Override
+ public StoredProcedureQuery createStoredProcedureQuery(String s, String... strings) {
+ return null;
+ }
+
+ @Override
+ public void joinTransaction() {
+
+ }
+
+ @Override
+ public boolean isJoinedToTransaction() {
+ return false;
+ }
+
+ @Override
+ public T unwrap(Class aClass) {
+ return null;
+ }
+
+ @Override
+ public Object getDelegate() {
+ return null;
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ @Override
+ public boolean isOpen() {
+ return false;
+ }
+
+ @Override
+ public EntityTransaction getTransaction() {
+ return null;
+ }
+
+ @Override
+ public EntityManagerFactory getEntityManagerFactory() {
+ return null;
+ }
+
+ @Override
+ public CriteriaBuilder getCriteriaBuilder() {
+ return null;
+ }
+
+ @Override
+ public Metamodel getMetamodel() {
+ return null;
+ }
+
+ @Override
+ public EntityGraph createEntityGraph(Class aClass) {
+ return null;
+ }
+
+ @Override
+ public EntityGraph> createEntityGraph(String s) {
+ return null;
+ }
+
+ @Override
+ public EntityGraph> getEntityGraph(String s) {
+ return null;
+ }
+
+ @Override
+ public List> getEntityGraphs(Class aClass) {
+ return null;
+ }
+ };
+ }
+
+ @Override
+ public void close() {
+
+ }
+
+ public TicketEntity saveTicketStatus(String ticket, Number expireSeconds, String status) {
+ logger.info(String.format("saveTicketStatus by %s%n%s%n", ticket, expireSeconds, status));
+
+ var entity = new TicketEntity();
+ entity.setId(UUID.randomUUID().toString());
+ entity.setTicket(ticket);
+ entity.setStatus(status);
+ entity.setExpireSeconds(expireSeconds);
+ entity.setTicketCreatedAt(System.currentTimeMillis() / 1000L);
+ em.persist(entity);
+
+ return entity;
+ }
+
+ public TicketEntity getTicketStatus(String ticket) {
+ logger.info(String.format("getTicketStatus by %s%n", ticket));
+
+ var ticketEntity = em.find(TicketEntity.class, ticket);
+
+ logger.info(String.format("ticketEntity is %s%n", ticketEntity));
+
+ return ticketEntity;
+ }
+
+ public TicketEntity saveTicketStatus(TicketEntity ticket) {
+ logger.info(String.format("saveTicketStatus by %s%n", ticket));
+
+ if (Objects.equals(ticket.getStatus(), "expired")) {
+ em.remove(ticket);
+ } else {
+ em.persist(ticket);
+ }
+
+ return ticket;
+ }
+}
diff --git a/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProviderFactory.java b/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProviderFactory.java
new file mode 100644
index 0000000..6679997
--- /dev/null
+++ b/src/main/java/org/keycloak/social/weixin/cache/TicketStatusProviderFactory.java
@@ -0,0 +1,19 @@
+package org.keycloak.social.weixin.cache;
+
+import org.keycloak.component.ComponentModel;
+import org.keycloak.models.KeycloakSession;
+import org.keycloak.storage.UserStorageProviderFactory;
+
+import java.util.Properties;
+
+public class TicketStatusProviderFactory implements UserStorageProviderFactory {
+ @Override
+ public TicketStatusProvider create(KeycloakSession keycloakSession, ComponentModel componentModel) {
+ return new TicketStatusProvider(keycloakSession, componentModel);
+ }
+
+ @Override
+ public String getId() {
+ return "TicketStatusProvider";
+ }
+}
diff --git a/src/main/java/org/keycloak/social/weixin/egress/wechat/mp/WechatMpApi.java b/src/main/java/org/keycloak/social/weixin/egress/wechat/mp/WechatMpApi.java
index 8fa8c40..357c8d6 100644
--- a/src/main/java/org/keycloak/social/weixin/egress/wechat/mp/WechatMpApi.java
+++ b/src/main/java/org/keycloak/social/weixin/egress/wechat/mp/WechatMpApi.java
@@ -4,6 +4,8 @@
import org.jboss.logging.Logger;
import org.keycloak.broker.provider.util.SimpleHttp;
import org.keycloak.models.KeycloakSession;
+import org.keycloak.sessions.AuthenticationSessionModel;
+import org.keycloak.social.weixin.cache.TicketStatusProvider;
import org.keycloak.social.weixin.egress.wechat.mp.models.AccessTokenResponse;
import org.keycloak.social.weixin.egress.wechat.mp.models.TicketRequest;
import org.keycloak.social.weixin.egress.wechat.mp.models.TicketResponse;
@@ -14,11 +16,13 @@ public class WechatMpApi {
private final String appSecret;
private final String appId;
protected final KeycloakSession session;
+ protected final AuthenticationSessionModel authenticationSession;
- public WechatMpApi(String appId, String appSecret, KeycloakSession session) {
+ public WechatMpApi(String appId, String appSecret, KeycloakSession session, AuthenticationSessionModel authenticationSession) {
this.appId = appId;
this.appSecret = appSecret;
this.session = session;
+ this.authenticationSession = authenticationSession;
}
@SneakyThrows
@@ -43,6 +47,16 @@ public TicketResponse createTmpQrCode(TicketRequest ticketRequest) {
logger.info(String.format("res is %s%n", res));
+ this.saveTicketStatus(res.ticket, res.expire_seconds);
+
return res;
}
+
+ private void saveTicketStatus(String ticket, Number expireSeconds) {
+ logger.info(String.format("saveTicketStatus by %s%n%s%n", ticket, expireSeconds));
+
+ var ticketStatusProvider = new TicketStatusProvider(session, null);
+
+ ticketStatusProvider.saveTicketStatus(ticket, expireSeconds, "not_scanned");
+ }
}
diff --git a/src/main/java/org/keycloak/social/weixin/resources/QrCodeResourceProvider.java b/src/main/java/org/keycloak/social/weixin/resources/QrCodeResourceProvider.java
index b40f167..0a1cde6 100644
--- a/src/main/java/org/keycloak/social/weixin/resources/QrCodeResourceProvider.java
+++ b/src/main/java/org/keycloak/social/weixin/resources/QrCodeResourceProvider.java
@@ -1,26 +1,31 @@
package org.keycloak.social.weixin.resources;
-import jakarta.ws.rs.GET;
-import jakarta.ws.rs.Path;
-import jakarta.ws.rs.Produces;
-import jakarta.ws.rs.QueryParam;
+import jakarta.ws.rs.*;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;
-import lombok.RequiredArgsConstructor;
+import lombok.SneakyThrows;
import org.jboss.logging.Logger;
import org.keycloak.models.KeycloakSession;
import org.keycloak.services.resource.RealmResourceProvider;
-import org.keycloak.social.weixin.egress.wechat.mp.WechatMpApi;
-import org.keycloak.social.weixin.egress.wechat.mp.models.ActionInfo;
-import org.keycloak.social.weixin.egress.wechat.mp.models.Scene;
-import org.keycloak.social.weixin.egress.wechat.mp.models.TicketRequest;
+import org.keycloak.social.weixin.cache.TicketStatusProvider;
+import org.w3c.dom.Document;
+import org.xml.sax.InputSource;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
+import java.io.StringReader;
import java.util.Map;
+import java.util.Objects;
-@RequiredArgsConstructor
public class QrCodeResourceProvider implements RealmResourceProvider {
private final KeycloakSession session;
protected static final Logger logger = Logger.getLogger(QrCodeResourceProvider.class);
+ private final TicketStatusProvider ticketStatusProvider;
+
+ public QrCodeResourceProvider(KeycloakSession session) {
+ this.session = session;
+ this.ticketStatusProvider = new TicketStatusProvider(session, null);
+ }
@Override
public Object getResource() {
@@ -43,7 +48,7 @@ public Response helloAnonymous() {
@GET
@Path("mp-qr")
@Produces(MediaType.TEXT_HTML)
- public Response mpQrUrl(@QueryParam("ticket-url") String ticketUrl) {
+ public Response mpQrUrl(@QueryParam("ticket") String ticket, @QueryParam("qr-code-url") String qrCodeUrl) {
logger.info("展示一个 HTML 页面,该页面使用 React 展示一个组件,它调用一个后端 API,得到一个带参二维码 URL,并将该 URL 作为 img 的 src 属性值");
String htmlContent = "\n" +
@@ -53,7 +58,7 @@ public Response mpQrUrl(@QueryParam("ticket-url") String ticketUrl) {
"\n" +
"\n" +
" \n" +
- "
\n" +
+ "
\n" +
"
\n" +
"\n" +
"